diff --git a/tools/bazel.rc b/.bazelrc similarity index 80% rename from tools/bazel.rc rename to .bazelrc index 601e07ffddec9f2b11518b4b2e82bea4fc2201cc..d5d20309df82498a552df759e3d200a914a4cfb7 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -24,12 +24,13 @@ build --define framework_shared_object=true # Please note that MKL on MacOS or windows is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. -build:mkl --define=using_mkl=true +build:mkl --define=build_with_mkl=true --define=enable_mkl=true build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. -build:mkl_open_source_only --define=using_mkl_dnn_only=true +build:mkl_open_source_only --define=build_with_mkl_dnn_only=true +build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -42,6 +43,9 @@ build:download_clang_use_lld --linkopt='-fuse-ld=lld' build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true +build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true + build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true @@ -57,6 +61,11 @@ build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fn build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true +# Options extracted from configure script +build:gdr --define=with_gdr_support=true +build:ngraph --define=with_ngraph_support=true +build:verbs --define=with_verbs_support=true + build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true build --define=grpc_no_ares=true @@ -65,5 +74,15 @@ build --spawn_strategy=standalone build --genrule_strategy=standalone build -c opt +# Other build flags. +build --define=grpc_no_ares=true + # Modular TF build options build:dynamic_kernels --define=dynamic_loaded_kernels=true + +# Default paths for TF_SYSTEM_LIBS +build --define=PREFIX=/usr +build --define=LIBDIR=$(PREFIX)/lib +build --define=INCLUDEDIR=$(PREFIX)/include + +# Do not commit the tf_configure.bazelrc line diff --git a/.gitignore b/.gitignore index 1ef4c297ee4f369775c13b32a46a55887de719e7..cb65f447d4a551266e237714a16d71b58bcfc51d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/CODEOWNERS b/CODEOWNERS index 78f80c8d718983f00fd5010c3fe5d561124d3714..94cc865479cd6ab5cdb589490d3a2d650f06b160 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,6 +2,7 @@ /tenosrflow/core/debug @caisq /tensorflow/core/platform/windows/ @mrry +/tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar /tensorflow/java/ @asimshankar /tensorflow/python/debug @caisq @@ -30,14 +31,16 @@ /tensorflow/contrib/gan/ @joel-shor /tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ +/tensorflow/contrib/hadoop @yongtang /tensorflow/contrib/hvx/ @satok16 /tensorflow/contrib/integrate/ @shoyer +/tensorflow/contrib/kafka @yongtang /tensorflow/contrib/kernel_methods/ @petrosmol +/tensorflow/contrib/kinesis @yongtang /tensorflow/contrib/ios_examples/ @petewarden /tensorflow/contrib/labeled_tensor/ @shoyer /tensorflow/contrib/layers/ @fchollet @martinwicke /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -/tensorflow/contrib/linalg/ @langmore /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis /tensorflow/contrib/lookup/ @ysuematsu @andreasst /tensorflow/contrib/losses/ @alextp @ispirmustafa diff --git a/README.md b/README.md index e3092e551e32d7f01e9bebd65323d1b5691f0269..57efb876c9afaf9fe76c4ced4e6a1572e9241edf 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). ## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* +*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* People who are a little more adventurous can also try our nightly binaries: @@ -48,15 +48,12 @@ $ python ``` ```python >>> import tensorflow as tf +>>> tf.enable_eager_execution() +>>> tf.add(1, 2) +3 >>> hello = tf.constant('Hello, TensorFlow!') ->>> sess = tf.Session() ->>> sess.run(hello) +>>> hello.numpy() 'Hello, TensorFlow!' ->>> a = tf.constant(10) ->>> b = tf.constant(32) ->>> sess.run(a + b) -42 ->>> sess.close() ``` Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). @@ -106,13 +103,13 @@ The TensorFlow project strives to abide by generally accepted best practices in ## For more information +* [TensorFlow Website](https://www.tensorflow.org) +* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow Twitter](https://twitter.com/tensorflow) * [TensorFlow Blog](https://medium.com/tensorflow) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [TensorFlow Twitter](https://twitter.com/tensorflow) -* [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) diff --git a/RELEASE.md b/RELEASE.md index 763ef3b279dde209ed387534032deae40a33a9e4..20e1d9217b7684e696d0abf427eef9ab9548d1b7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,86 @@ +# Release 1.11.0 + +## Major Features and Improvements + +* Nvidia GPU: + * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) +* Google Cloud TPU: + * Experimental tf.data integration for Keras on Google Cloud TPUs. + * Experimental / preview support for eager execution on Google Cloud TPUs. +* DistributionStrategy: + * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs. + * Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details. +* Add C, C++, and Python functions for querying kernels + +## Breaking Changes + +* Keras: + * The default values for tf.keras `RandomUniform`, `RandomNormal`, and `TruncatedNormal` initializers have been changed to match those in external Keras. + * Breaking change: `model.get_config()` on a Sequential model now returns a config dictionary (consistent with other Model instances) instead of a list of configs for the underlying layers. + +## Bug Fixes and Other Changes + +* C++: + * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure. +* tf.data: + * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. + * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files. + * Renamed BigTable class to BigtableTable for clarity + * Document use of the Cloud Bigtable API + * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element. + * Generalization of `tf.contrib.data.sliding_window_batch`. +* INC: + * Runtime improvements to triangular solve. +* `tf.contrib`: + * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows to use `padding=same`. + * Add documentation clarifying the differences between tf.fill and tf.constant. + * Add experimental IndexedDatasets. + * Add selective registration target using the lite proto runtime. + * Add simple Tensor and DataType classes to TensorFlow Lite Java + * Add support for bitcasting to/from uint32 and uint64. + * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator). + * Adds leaf index modes as an argument. + * Allow a different output shape from the input in tf.contrib.image.transform. + * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells. + * Deprecate self.test_session() in favor of self.session() or self.cached_session(). + * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon) + * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. + * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. + * Fix toco compilation/execution on Windows + * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. + * It is now safe to call any of the C API's TF_Delete\* functions on nullptr + * Log some errors on Android to logcat + * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. + * Optional bucket location check for the GCS Filesystem. + * Performance enhancements for StringSplitOp & StringSplitV2Op. + * Performance improvements for regex replace operations. + * TFRecordWriter now raises an error if .write() fails. + * TPU: More helpful error messages in TPUClusterResolvers. + * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time. + * The protocol used for Estimator training is now configurable in RunConfig. + * Triangular solve performance improvements. + * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will use to replace the existing zero_state() method. + * Update initialization of variables in Keras. + * Updates to "constrained_optimization" in tensorflow/contrib. + * boosted trees: adding pruning mode + * tf.train.Checkpoint does not delete old checkpoints by default. + * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Aapeli, adoda, Ag Ramesh, Amogh Mannekote, Andrew Gibiansky, Andy Craze, Anirudh Koul, Aurelien Geron, Avijit, Avijit-Nervana, Ben, Benjamin H. Myara, bhack, Brett Koonce, Cao Zongyan, cbockman, cheerss, Chikanaga Tomoyuki, Clayne Robison, cosine0, Cui Wei, Dan J, David, David Norman, Dmitry Klimenkov, Eliel Hojman, Florian Courtial, fo40225, formath, Geoffrey Irving, gracehoney, Grzegorz Pawelczak, Guoliang Hua, Guozhong Zhuang, Herman Zvonimir DošIlović, HuiyangFei, Jacker, Jan HüNnemeyer, Jason Taylor, Jason Zaman, Jesse, Jiang,Zhoulong, Jiawei Zhang, Jie, Joe Yearsley, Johannes Schmitz, Jon Perl, Jon Triebenbach, Jonathan, Jonathan Hseu, Jongmin Park, Justin Shenk, karl@kubx.ca, Kate Hodesdon, Kb Sriram, Keishi Hattori, Kenneth Blomqvist, Koan-Sin Tan, Li Liangbin, Li, Yiqiang, Loo Rong Jie, Madiyar, Mahmoud Abuzaina, Mark Ryan, Matt Dodge, mbhuiyan, melvinljy96, Miguel Mota, Nafis Sadat, Nathan Luehr, naurril, Nehal J Wani, Niall Moran, Niranjan Hasabnis, Nishidha Panpaliya, npow, olicht, Pei Zhang, Peng Wang (Simpeng), Peng Yu, Philipp Jund, Pradeep Banavara, Pratik Kalshetti, qwertWZ, Rakesh Chada, Randy West, Ray Kim, Rholais Lii, Robin Richtsfeld, Rodrigo Silveira, Ruizhi, Santosh Kumar, Seb Bro, Sergei Lebedev, sfujiwara, Shaba Abhiram, Shashi, SneakyFish5, Soila Kavulya, Stefan Dyulgerov, Steven Winston, Sunitha Kambhampati, Surry Shome, Taehoon Lee, Thor Johnsen, Tristan Rice, TShapinsky, tucan, tucan9389, Vicente Reyes, Vilmar-Hillow, Vitaly Lavrukhin, wangershi, weidan.kong, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Wim Glenn, XFeiF, Yan Facai (颜发才), Yanbo Liang, Yong Tang, Yoshihiro Yamazaki, Yuan (Terry) Tang, Yuan, Man, zhaoyongke, ÁRon +Ricardo Perez-Lopez, 张天启, 张晓飞 + + +# Release 1.10.1 +## Bug Fixes and Other Changes + +* `tf.keras`: + * Fixing keras on Cloud TPUs. No new binaries will be built for Windows. + + # Release 1.10.0 ## Major Features And Improvements @@ -11,7 +94,7 @@ ## Breaking Changes -* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source). * Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. ## Bug Fixes and Other Changes diff --git a/configure.py b/configure.py index 361bd4764dc5c1900be7378f51c00aedf6f2ce41..89dc79b6b6bb168339d05182fd9da47dfc90ce54 100644 --- a/configure.py +++ b/configure.py @@ -35,13 +35,11 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '2.2' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -49,10 +47,18 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 -_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' -_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) -_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') +_TF_WORKSPACE_ROOT = '' +_TF_BAZELRC = '' + +NCCL_LIB_PATHS = [ + 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' +] + +if platform.machine() == 'ppc64le': + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' +else: + _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() class UserInputError(Exception): @@ -153,14 +159,18 @@ def get_python_path(environ_cp, python_bin_path): if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') try: - library_paths = run_shell( - [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') + library_paths = run_shell([ + python_bin_path, '-c', + 'import site; print("\\n".join(site.getsitepackages()))' + ]).split('\n') except subprocess.CalledProcessError: - library_paths = [run_shell( - [python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())'])] + library_paths = [ + run_shell([ + python_bin_path, '-c', + 'from distutils.sysconfig import get_python_lib;' + 'print(get_python_lib())' + ]) + ] all_paths = set(python_paths + library_paths) @@ -187,8 +197,7 @@ def setup_python(environ_cp): environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path, default_python_bin_path) # Check if the path is valid - if os.path.isfile(python_bin_path) and os.access( - python_bin_path, os.X_OK): + if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): print('Invalid python path: %s cannot be found.' % python_bin_path) @@ -217,7 +226,7 @@ def setup_python(environ_cp): python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path - python_major_version = get_python_major_version(python_bin_path) + _ = get_python_major_version(python_bin_path) # Convert python path to Windows style before writing into bazel.rc if is_windows() or is_cygwin(): @@ -230,15 +239,16 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open(os.path.join( - _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: + with open( + os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), + 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(workspace_path): +def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(workspace_path, '.bazelrc') + bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') data = [] if os.path.exists(bazelrc_path): @@ -249,20 +259,15 @@ def reset_tf_configure_bazelrc(workspace_path): if _TF_BAZELRC_FILENAME in l: continue f.write('%s\n' % l) - if is_windows(): - tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") - else: - tf_bazelrc_path = _TF_BAZELRC - f.write('import %s\n' % tf_bazelrc_path) - + f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. These files could interfere with Bazel parsing. """ - makefile_download_dir = os.path.join( - _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') + makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow', + 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -330,9 +335,8 @@ def get_var(environ_cp, 'Environment variable %s must be set as a boolean indicator.\n' 'The following are accepted as TRUE : %s.\n' 'The following are accepted as FALSE: %s.\n' - 'Current value is %s.' % ( - var_name, ', '.join(true_strings), ', '.join(false_strings), - var)) + 'Current value is %s.' % (var_name, ', '.join(true_strings), + ', '.join(false_strings), var)) while var is None: user_input_origin = get_input(question) @@ -355,8 +359,12 @@ def get_var(environ_cp, return var -def set_build_var(environ_cp, var_name, query_item, option_name, - enabled_by_default, bazel_config_name=None): +def set_build_var(environ_cp, + var_name, + query_item, + option_name, + enabled_by_default, + bazel_config_name=None): """Set if query_item will be enabled for the build. Ask user if query_item will be enabled. Default is used if no input is given. @@ -375,12 +383,14 @@ def set_build_var(environ_cp, var_name, query_item, option_name, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var if var == '1': - write_to_bazelrc('build --define %s=true' % option_name) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build --config=%s' % bazel_config_name) elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. - write_to_bazelrc('build:%s --define %s=true' - % (bazel_config_name, option_name)) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) def set_action_env_var(environ_cp, @@ -447,7 +457,8 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) + curr_version = run_shell( + ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -499,6 +510,7 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') + def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -581,16 +593,14 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) -def prompt_loop_or_load_from_env( - environ_cp, - var_name, - var_default, - ask_for_var, - check_success, - error_msg, - suppress_default_error=False, - n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS -): +def prompt_loop_or_load_from_env(environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS): """Loop over user prompts for an ENV param until receiving a valid response. For the env param var_name, read from the environment or verify user input @@ -629,9 +639,7 @@ def prompt_loop_or_load_from_env( ) for _ in range(n_ask_attempts): - val = get_from_env_or_user_or_default(environ_cp, - var_name, - full_query, + val = get_from_env_or_user_or_default(environ_cp, var_name, full_query, default) if check_success(val): break @@ -639,9 +647,9 @@ def prompt_loop_or_load_from_env( print(error_msg % val) environ_cp[var_name] = '' else: - raise UserInputError('Invalid %s setting was provided %d times in a row. ' - 'Assuming to be a scripting mistake.' % - (var_name, n_ask_attempts)) + raise UserInputError( + 'Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts)) environ_cp[var_name] = val return val @@ -650,8 +658,8 @@ def prompt_loop_or_load_from_env( def create_android_ndk_rule(environ_cp): """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" if is_windows() or is_cygwin(): - default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % - environ_cp['APPDATA']) + default_ndk_path = cygpath( + '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA']) elif is_macos(): default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] else: @@ -668,8 +676,7 @@ def create_android_ndk_rule(environ_cp): ask_for_var='Please specify the home path of the Android NDK to use.', check_success=valid_ndk_path, error_msg=('The path %s or its child file "source.properties" ' - 'does not exist.') - ) + 'does not exist.')) write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', check_ndk_level(android_ndk_home_path)) @@ -703,9 +710,9 @@ def create_android_sdk_rule(environ_cp): api_levels = [x.replace('android-', '') for x in api_levels] def valid_api_level(api_level): - return os.path.exists(os.path.join(android_sdk_home_path, - 'platforms', - 'android-' + api_level)) + return os.path.exists( + os.path.join(android_sdk_home_path, 'platforms', + 'android-' + api_level)) android_api_level = prompt_loop_or_load_from_env( environ_cp, @@ -720,9 +727,8 @@ def create_android_sdk_rule(environ_cp): versions = sorted(os.listdir(build_tools)) def valid_build_tools(version): - return os.path.exists(os.path.join(android_sdk_home_path, - 'build-tools', - version)) + return os.path.exists( + os.path.join(android_sdk_home_path, 'build-tools', version)) android_build_tools_version = prompt_loop_or_load_from_env( environ_cp, @@ -736,10 +742,8 @@ def create_android_sdk_rule(environ_cp): write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', android_build_tools_version) - write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', - android_api_level) - write_action_env_to_bazelrc('ANDROID_SDK_HOME', - android_sdk_home_path) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -798,6 +802,7 @@ def reformat_version_sequence(version_str, sequence_count): Args: version_str: String, the version string. sequence_count: int, an integer. + Returns: string, reformatted version string. """ @@ -841,18 +846,25 @@ def set_tf_cuda_version(environ_cp): if is_windows(): cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version) - for x in ['lib64', 'lib/x86_64-linux-gnu']] + cuda_rt_lib_paths = [ + '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [ + 'lib64', + 'lib/powerpc64le-linux-gnu', + 'lib/x86_64-linux-gnu', + ] + ] elif is_macos(): cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] + cuda_toolkit_paths_full = [ + os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths + ] if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): break # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_path_full)) + (tf_cuda_version, cuda_toolkit_paths_full)) environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' @@ -872,7 +884,7 @@ def set_tf_cudnn_version(environ_cp): """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. ' - '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION + '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( @@ -919,8 +931,8 @@ def set_tf_cudnn_version(environ_cp): cudnn_path_from_ldconfig) if cudnn_path_from_ldconfig: cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) - if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, - tf_cudnn_version)): + if os.path.exists( + '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) break @@ -1029,7 +1041,7 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) - if len(matches.groups()) == 0: + if not matches.groups(): continue ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 @@ -1085,7 +1097,7 @@ def set_tf_tensorrt_install_path(environ_cp): def set_tf_nccl_install_path(environ_cp): - """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION. Args: environ_cp: copy of the os.environ. @@ -1098,59 +1110,119 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. If NCCL %s is not ' - 'installed, then you can use version 1.3 that can be fetched ' - 'automatically but it may have worse performance with multiple GPUs. ' - '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) + 'Please specify the locally installed NCCL version you want to use. ' + '[Default is to use https://github.com/nvidia/nccl]: ') for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( - environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) - tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '') + + if not tf_nccl_version: + break # No need to get install path, building the open source code. - if tf_nccl_version == '1': - break # No need to get install path, NCCL 1 is a GitHub repo. + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) - # TODO(csigg): Look with ldconfig first if we can find the library in paths + # Look with ldconfig first if we can find the library in paths # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding # include directory. This is where the NCCL .deb packages install them. - # Then ask the user if we should use that. Instead of a single - # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to - # nccl_configure.bzl - default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') - ask_nccl_path = (r'Please specify the location where NCCL %s library is ' - 'installed. Refer to README.md for more details. [Default ' - 'is %s]:') % (tf_nccl_version, default_nccl_path) - nccl_install_path = get_from_env_or_user_or_default( - environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) - if is_windows() or is_cygwin(): - nccl_install_path = cygpath(nccl_install_path) - - if is_windows(): - nccl_lib_path = 'lib/x64/nccl.lib' - elif is_linux(): - nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version - elif is_macos(): - nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version - - nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) - nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): - # Set NCCL_INSTALL_PATH - environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path - write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) - break - - # Reset and Retry - print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + # First check to see if NCCL is in the ldconfig. + # If its found, use that location. + if is_linux(): + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) + nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)', + nccl2_path_from_ldconfig) + if nccl2_path_from_ldconfig: + nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1) + if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)): + nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig) + print('NCCL libraries found in ' + nccl2_path_from_ldconfig) + + # Check if this is the main system lib location + if re.search('.*linux-gnu', nccl_install_path): + trunc_nccl_install_path = '/usr' + print('This looks like a system path.') + else: + trunc_nccl_install_path = nccl_install_path + '/..' + + # Look for header + nccl_hdr_path = trunc_nccl_install_path + '/include' + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_hdr_path + '/nccl.h'): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path + write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path) + break + else: + print( + 'The header for NCCL2 cannot be found. Please install the libnccl-dev package.' + ) + else: + print('NCCL2 is listed by ldconfig but the library is not found. ' + 'Your ldconfig is out of date. Please run sudo ldconfig.') + else: + # NCCL is not found in ldconfig. Ask the user for the location. + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = ( + r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath( + os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version + nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename) + if not os.path.exists(nccl_lpath): + for relative_path in NCCL_LIB_PATHS: + path = '%s/%s%s' % (nccl_install_path, relative_path, + nccl_lib_filename) + if os.path.exists(path): + print('NCCL found at ' + path) + nccl_lib_path = path + break + else: + nccl_lib_path = nccl_lpath + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join( + os.path.dirname(nccl_lib_path), '../include/nccl.h') + print('Assuming NCCL header path is ' + nccl_hdr_path) + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path) + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', + os.path.dirname(nccl_lib_path)) + + # Set NCCL_HDR_PATH + environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path) + write_action_env_to_bazelrc('NCCL_HDR_PATH', + os.path.dirname(nccl_hdr_path)) + break + + # Reset and Retry + print( + 'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, nccl_hdr_path)) - environ_cp['TF_NCCL_VERSION'] = '' + environ_cp['TF_NCCL_VERSION'] = '' else: raise UserInputError('Invalid TF_NCCL setting was provided %d ' 'times in a row. Assuming to be a scripting mistake.' % @@ -1160,12 +1232,12 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) - def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. Args: environ_cp: copy of the os.environ. + Returns: string of native cuda compute capabilities, separated by comma. """ @@ -1290,8 +1362,7 @@ def set_computecpp_toolkit_path(environ_cp): else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(toolkit_path, - sycl_rt_lib_path) + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) exists = os.path.exists(sycl_rt_lib_path_full) if not exists: print('Invalid SYCL %s library path. %s cannot be found' % @@ -1319,8 +1390,8 @@ def set_trisycl_include_dir(environ_cp): ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 'include directory. (Use --config=sycl_trisycl ' 'when building with Bazel) ' - '[Default is %s]: ' - ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + '[Default is %s]: ') % ( + _DEFAULT_TRISYCL_INCLUDE_DIR) while True: trisycl_include_dir = get_from_env_or_user_or_default( @@ -1329,13 +1400,12 @@ def set_trisycl_include_dir(environ_cp): if os.path.exists(trisycl_include_dir): break - print('Invalid triSYCL include directory, %s cannot be found' - % (trisycl_include_dir)) + print('Invalid triSYCL include directory, %s cannot be found' % + (trisycl_include_dir)) # Set TRISYCL_INCLUDE_DIR environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', - trisycl_include_dir) + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) def set_mpi_home(environ_cp): @@ -1345,8 +1415,9 @@ def set_mpi_home(environ_cp): default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) def valid_mpi_path(mpi_home): - exists = (os.path.exists(os.path.join(mpi_home, 'include')) and - os.path.exists(os.path.join(mpi_home, 'lib'))) + exists = ( + os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) if not exists: print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % (os.path.join(mpi_home, 'include'), @@ -1395,16 +1466,22 @@ def set_other_mpi_vars(environ_cp): raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) -def set_grpc_build_flags(): - write_to_bazelrc('build --define grpc_no_ares=true') - - def set_system_libs_flag(environ_cp): syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') - syslibs = ','.join(sorted(syslibs.split(','))) - if syslibs and syslibs != '': + if syslibs: + if ',' in syslibs: + syslibs = ','.join(sorted(syslibs.split(','))) + else: + syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) + if 'PREFIX' in environ_cp: + write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) + if 'LIBDIR' in environ_cp: + write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) + if 'INCLUDEDIR' in environ_cp: + write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) + def set_windows_build_flags(environ_cp): """Set Windows specific build options.""" @@ -1421,14 +1498,20 @@ def set_windows_build_flags(environ_cp): # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 # Short object file path will be enabled by default. write_to_bazelrc('build --experimental_shortened_obj_file_path=true') + # When building zip file for some py_binary and py_test targets, don't + # include its dependencies. This is for: + # 1. Running python tests against the system installed TF pip package. + # 2. Avoiding redundant files in + # //tensorflow/tools/pip_package:simple_console_windows, + # which is a py_binary used during creating TF pip package. + # See https://github.com/tensorflow/tensorflow/issues/22390 + write_to_bazelrc('build --define=no_tensorflow_py_deps=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', - True, - ('Would you like to override eigen strong inline for some C++ ' - 'compilation to reduce the compilation time?'), - 'Eigen strong inline overridden.', - 'Not overriding eigen strong inline, ' + True, ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, ' 'some compilations could take more than 20 mins.'): # Due to a known MSVC compiler issue # https://github.com/tensorflow/tensorflow/issues/10521 @@ -1444,29 +1527,31 @@ def config_info_line(name, help_text): def main(): + global _TF_WORKSPACE_ROOT + global _TF_BAZELRC + parser = argparse.ArgumentParser() - parser.add_argument("--workspace", - type=str, - default=_TF_WORKSPACE_ROOT, - help="The absolute path to your active Bazel workspace.") + parser.add_argument( + '--workspace', + type=str, + default=os.path.abspath(os.path.dirname(__file__)), + help='The absolute path to your active Bazel workspace.') args = parser.parse_args() + _TF_WORKSPACE_ROOT = args.workspace + _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) check_bazel_version('0.15.0') - reset_tf_configure_bazelrc(args.workspace) + reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_AWS'] = '0' - environ_cp['TF_NEED_GCP'] = '0' - environ_cp['TF_NEED_HDFS'] = '0' - environ_cp['TF_NEED_JEMALLOC'] = '0' - environ_cp['TF_NEED_KAFKA'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' @@ -1476,40 +1561,24 @@ def main(): # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_ENABLE_XLA'] = '0' - environ_cp['TF_NEED_GDR'] = '0' - environ_cp['TF_NEED_VERBS'] = '0' environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): - environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + environ_cp['TF_ENABLE_XLA'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at # runtime to allow the Tensorflow testcases which compare numpy # results to Tensorflow results to succeed. if is_ppc64le(): - write_action_env_to_bazelrc("OMP_NUM_THREADS", 1) - - set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', - 'with_jemalloc', True) - set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', - 'with_gcp_support', True, 'gcp') - set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', - 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', - 'with_aws_support', True, 'aws') - set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', True, 'kafka') + write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) + + set_build_var(environ_cp, 'TF_NEED_IGNITE', 'Apache Ignite', + 'with_ignite_support', True, 'ignite') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', - False, 'xla') - set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', - False, 'gdr') - set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', - False, 'verbs') - set_build_var(environ_cp, 'TF_NEED_NGRAPH', 'nGraph', - 'with_ngraph_support', False, 'ngraph') + True, 'xla') set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1521,6 +1590,13 @@ def main(): else: set_trisycl_include_dir(environ_cp) + set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False) + if (environ_cp.get('TF_NEED_ROCM') == '1' and + 'LD_LIBRARY_PATH' in environ_cp and + environ_cp.get('LD_LIBRARY_PATH') != '1'): + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) + set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): @@ -1561,24 +1637,36 @@ def main(): write_to_bazelrc('build --config=download_clang') write_to_bazelrc('test --config=download_clang') + # SYCL / ROCm / CUDA are mutually exclusive. + # At most 1 GPU platform can be configured. + gpu_platform_count = 0 + if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': + gpu_platform_count += 1 + if environ_cp.get('TF_NEED_ROCM') == '1': + gpu_platform_count += 1 + if environ_cp.get('TF_NEED_CUDA') == '1': + gpu_platform_count += 1 + if gpu_platform_count >= 2: + raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. ' + 'At most 1 GPU platform can be configured.') + set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) - set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_system_libs_flag(environ_cp) if is_windows(): set_windows_build_flags(environ_cp) - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): + # Add a config option to build TensorFlow 2.0 API. + write_to_bazelrc('build:v2 --define=tf_api_version=2') + + if get_var(environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) @@ -1587,10 +1675,14 @@ def main(): # TODO(pcloudy): remove the following if check when they make sense on Windows if not is_windows(): print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See tools/bazel.rc for ' - 'more details.') + 'adding "--config=<>" to your build command. See .bazelrc for more ' + 'details.') config_info_line('mkl', 'Build with MKL support.') config_info_line('monolithic', 'Config for mostly static monolithic build.') + config_info_line('gdr', 'Build with GDR support.') + config_info_line('verbs', 'Build with libverbs support.') + config_info_line('ngraph', 'Build with Intel nGraph support.') + if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 386e0096ff705c2eaa98f42833ef650bac6fc8d8..9b62a504525d5377d4836e92bdf0e46f7fc3ef38 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -203,21 +203,6 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(jhseu): Enable on other platforms other than Linux. -config_setting( - name = "with_jemalloc_linux_x86_64", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "k8"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_jemalloc_linux_ppc64le", - define_values = {"with_jemalloc": "true"}, - values = {"cpu": "ppc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "with_default_optimizations", define_values = {"with_default_optimizations": "true"}, @@ -225,56 +210,8 @@ config_setting( ) config_setting( - name = "with_gcp_support", - define_values = {"with_gcp_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support", - define_values = {"with_hdfs_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support", - define_values = {"with_aws_support": "true"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_kafka_support", - define_values = {"with_kafka_support": "true"}, - visibility = ["//visibility:public"], -) - -# Crosses between platforms and file system libraries not supported on those -# platforms due to limitations in nested select() statements. -config_setting( - name = "with_gcp_support_windows_override", - define_values = {"with_gcp_support": "true"}, - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support_windows_override", - define_values = {"with_hdfs_support": "true"}, - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support_windows_override", - define_values = {"with_aws_support": "true"}, - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_kafka_support_windows_override", - define_values = {"with_kafka_support": "true"}, - values = {"cpu": "x64_windows"}, + name = "with_ignite_support", + define_values = {"with_ignite_support": "true"}, visibility = ["//visibility:public"], ) @@ -285,48 +222,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_gcp_support_android_override", - define_values = {"with_gcp_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support_android_override", - define_values = {"with_hdfs_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support_android_override", - define_values = {"with_aws_support": "true"}, - values = {"crosstool_top": "//external:android/crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_gcp_support_ios_override", - define_values = {"with_gcp_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_hdfs_support_ios_override", - define_values = {"with_hdfs_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_aws_support_ios_override", - define_values = {"with_aws_support": "true"}, - values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, - visibility = ["//visibility:public"], -) - config_setting( name = "with_xla_support", define_values = {"with_xla_support": "true"}, @@ -355,30 +250,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_jemalloc_linux_x86_64_dynamic", - define_values = { - "with_jemalloc": "true", - "framework_shared_object": "true", - }, - values = { - "cpu": "k8", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "with_jemalloc_linux_ppc64le_dynamic", - define_values = { - "with_jemalloc": "true", - "framework_shared_object": "true", - }, - values = { - "cpu": "ppc", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "using_cuda_clang", define_values = { @@ -564,6 +435,7 @@ tf_cc_shared_object( "$(location //tensorflow/c:version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -588,6 +460,7 @@ tf_cc_shared_object( "$(location //tensorflow:tf_version_script.lds)", ], }), + visibility = ["//visibility:public"], deps = [ "//tensorflow:tf_exported_symbols.lds", "//tensorflow:tf_version_script.lds", @@ -608,6 +481,55 @@ exports_files( ], ) +genrule( + name = "install_headers", + srcs = [ + "//tensorflow/c:headers", + "//tensorflow/c/eager:headers", + "//tensorflow/cc:headers", + "//tensorflow/core:headers", + ], + outs = ["include"], + cmd = """ + mkdir $@ + for f in $(SRCS); do + d="$${f%/*}" + d="$${d#bazel-out*genfiles/}" + d="$${d#*external/eigen_archive/}" + + if [[ $${d} == *local_config_* ]]; then + continue + fi + + if [[ $${d} == external* ]]; then + extname="$${d#*external/}" + extname="$${extname%%/*}" + if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then + continue + fi + fi + + mkdir -p "$@/$${d}" + cp "$${f}" "$@/$${d}/" + done + """, + tags = ["manual"], + visibility = ["//visibility:public"], +) + +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + gen_api_init_files( name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], @@ -629,19 +551,6 @@ gen_api_init_files( root_init_template = "api_template.__init__.py", ) -genrule( - name = "root_init_gen", - srcs = select({ - "api_version_2": [":tf_python_api_gen_v2"], - "//conditions:default": [":tf_python_api_gen_v1"], - }), - outs = ["__init__.py"], - cmd = select({ - "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", - "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", - }), -) - py_library( name = "tensorflow_py", srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 53a72b84430ac703323e8235b4e3393d1c9898bc..2de740e145f93b151faf5c987808dbdf73fb4fd7 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -14,9 +14,9 @@ # ============================================================================== """Bring in all of the public TensorFlow interface into this module.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function import os as _os @@ -41,6 +41,11 @@ except (ImportError, AttributeError): from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader +# The templated code that replaces the placeholder above sometimes +# sets the __all__ variable. If it does, we have to be sure to add +# "contrib". +if '__all__' in vars(): + vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable @@ -51,10 +56,6 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disabl if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -del absolute_import -del division -del print_function - # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 43c279bd800d79eeaf9a25bbc1978148f93c0a50..17e2e292eb19029d279bc12a8328edadf96f1bb8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -246,6 +246,7 @@ tf_cc_test( ":c_api_experimental", ":c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index c195c9e01ca920c7234499b6e1d5e9cbf24056f3..d4b78138e93624a7e41e917f8210281b500661bc 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::FunctionDef; using tensorflow::Node; @@ -8508,6 +8509,20 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, VLOG(1) << "Enqueuing is done."; } +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { + tensorflow::ServerDef server_def; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, + &server_def)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for ServerDef: ", text_proto); + return nullptr; + } + status->status = tensorflow::Status(); + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(server_def, ret)); + return ret; +} + TFE_Context* TFE_CreateContextFromSession(TF_Session* session, TF_Status* status) { auto* opts = TFE_NewContextOptions(); @@ -8705,3 +8720,25 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, return createTFEDequeue(ctx, TF_VARIANT, queue, status); } + +static void CheckOk(TF_Status* status) { + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); +} + +void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { + auto* status = TF_NewStatus(); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::Tensor dst; + TF_CHECK_OK(TF_TensorToTensor(t, &dst)); + LOG(INFO) << dst.DebugString(); + + TF_DeleteTensor(t); + TF_DeleteStatus(status); +} + +TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, + const char* errMsg) { + status->status = tensorflow::errors::Internal(errMsg); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 522c91f67efdf10118268842dee3beb334fb720d..d98d532e32e891e21f5b7ba360c74c3256fb1947 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -131,6 +131,8 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, TF_Tensor* tensor, TF_Status* status); +// Create a serialized tensorflow.ServerDef proto. +TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); // TODO: remove this API in favor of the next one. TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( @@ -174,6 +176,13 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_Session* session, int tensor_id, TF_Status* status); +// Prints `handle` in a human readable format to standard output for debugging. +TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( + TFE_TensorHandle* handle); + +TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, + const char* errMsg); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 30fcfd401d9d634962d64aaa3bf348de91f2ecae..c6effd39697e0397278770b53e98508074f99862 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace tensorflow { namespace { @@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { TF_DeleteStatus(s); } +TEST(CAPI_EXPERIMENTAL, GetServerDefTest) { + const string expected_text_proto(R"(cluster { + job { + name: "worker" + tasks { + key: 0 + value: "tpuserver:0" + } + tasks { + key: 1 + value: "localhost:1" + } + } +} +job_name: "worker" +task_index: 1 +protocol: "grpc" +)"); + + TF_Status* status = TF_NewStatus(); + TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK); + + ServerDef actual; + ASSERT_TRUE(actual.ParseFromArray(result->data, result->length)); + string actual_text_proto; + tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto); + EXPECT_EQ(expected_text_proto, actual_text_proto); + + const string malformed_text_proto(R"(cluster { + job { + name: "worker")"); + TF_Buffer* null_result = + TFE_GetServerDef(malformed_text_proto.c_str(), status); + EXPECT_NE(TF_GetCode(status), TF_OK); + EXPECT_TRUE(tensorflow::str_util::StrContains( + TF_Message(status), "Invalid text proto for ServerDef")); + EXPECT_EQ(null_result, nullptr); + + // Cleanup + TF_DeleteBuffer(result); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 37be52f57d865c1e59611540d5dab04b59e89444..3ee31a6a7ac641bbd3fc4c05568b61e433a1d523 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -68,7 +68,10 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/deepmind/courier:__pkg__", + "//tensorflow:internal", + ], deps = [ ":c_api", "//tensorflow/c:c_api", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 349d9bcd7ca3991c7c3621f347af6025778612b7..3554ec0bf3202b54bfc38d67e51b89df19832302 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -375,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { return result; } +int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + tensorflow::int64 result; + status->status = h->handle->NumElements(&result); + return result; +} + int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { @@ -567,6 +578,21 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length) { + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(data, length); + op->operation.MutableAttrs()->Set(attr_name, attr_value); +} + +void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, + TF_Status* status) { + tensorflow::Tensor t; + status->status = TF_TensorToTensor(tensor, &t); + if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t); +} + void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 337447eec9581b01fa775affc49097986824a360..b2454d872207e26feb3764671474a5d87c01f84d 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, + TF_Status* status); // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, @@ -311,6 +313,14 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, + const char* attr_name, + TF_Tensor* tensor, + TF_Status* status); + TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 5607c9dcb0bbec72b2f86def3dd4e6590d73197b..008f088c2dcdd7d9114103516a4702e47a55c6de 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -99,8 +99,6 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TFE_OpAddInput(op, b, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); return op; diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index ce038a4b57b2699c6d09fcf75ef41cecec4e97b8..5ba55a203ff70cc64c07e96b5a869a1f11c9334e 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,15 +29,8 @@ limitations under the License. namespace tensorflow { namespace eager { -// Information about a tensor. -struct TapeTensor { - int64 id; // Expected to be unique in the lifetime of this process. - DataType dtype; - TensorShape shape; -}; - // Represents an entry in the tape. -template +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; @@ -57,8 +50,8 @@ struct OpTapeEntry { using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. -template -using OpTape = gtl::FlatMap>; +template +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap>; // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle // specialization, which is blocked by quite a few things needing to loop back // into python now. -template +template class VSpace { public: virtual ~VSpace() {} @@ -93,10 +86,10 @@ class VSpace { gtl::ArraySlice gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(const TapeTensor& tensor) const = 0; // Returns a Tensor which is filled with ones and like the input. - virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Ones(const TapeTensor& tensor) const = 0; // Calls the passed-in backward function. virtual Status CallBackwardFunction( @@ -114,7 +107,7 @@ class VSpace { // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. -template +template class GradientTape { public: // If `persistent` is true, GradientTape will not eagerly delete backward @@ -134,10 +127,10 @@ class GradientTape { void Watch(int64 tensor_id); void RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -146,17 +139,18 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. - Status ComputeGradient(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, - gtl::ArraySlice output_gradients, - std::vector* result); + Status ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); bool IsPersistent() const { return persistent_; } private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) { } } -template -bool GradientTape::ShouldRecord( +template +bool GradientTape::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { CHECK_EQ(tensor_ids.size(), dtypes.size()); @@ -201,20 +195,20 @@ bool GradientTape::ShouldRecord( return false; } -template -void GradientTape::Watch(int64 tensor_id) { +template +void GradientTape::Watch( + int64 tensor_id) { tensor_tape_.emplace(tensor_id, -1); } -template -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, +template +void GradientTape::RecordOperation( + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, + const std::function& backward_function_getter, const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(backward_function); return; } std::vector ids; @@ -229,16 +223,18 @@ void GradientTape::RecordOperation( for (const TapeTensor& o : output_tensors) { // Note: the tensor can have already been watched and hence be in the tape, // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; tensors.push_back(o); } - op_tape_[op_id] = OpTapeEntry{ - op_type, tensors, ids, backward_function, backward_function_deleter}; + op_tape_[op_id] = OpTapeEntry{ + op_type, std::move(tensors), std::move(ids), backward_function_getter(), + backward_function_deleter}; } -template -void GradientTape::DeleteTrace(int64 tensor_id) { +template +void GradientTape::DeleteTrace( + int64 tensor_id) { auto it = tensor_usage_.find(tensor_id); if (it == tensor_usage_.end()) { return; @@ -261,7 +257,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { auto op_it = op_tape_.find(op_id); CHECK(op_it != op_tape_.end()); for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { // Found a usage for an output, so cannot delete the op. return; } @@ -304,9 +300,9 @@ void GradientTape::DeleteTrace(int64 tensor_id) { namespace { -template +template struct BackpropInitialState { - OpTape op_tape; + OpTape op_tape; // Map from tensor ID to how many references still exist for this tensor in // the tape. @@ -322,17 +318,17 @@ struct BackpropInitialState { // If `persistent_tape` is false, op_tape is cleared and backwards functions // not needed for gradient computation are deleted. Backwards functions that // are needed, are copied and returned in BackpropInitialState. -template -BackpropInitialState PrepareBackprop( +template +BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, const gtl::FlatSet& sources_set, - bool persistent_tape) { + OpTape* op_tape, + const gtl::FlatSet& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { tensor_stack.push_back(t); } - BackpropInitialState result; + BackpropInitialState result; while (!tensor_stack.empty()) { int64 tensor_id = tensor_stack.back(); tensor_stack.pop_back(); @@ -383,9 +379,9 @@ BackpropInitialState PrepareBackprop( return result; } -template +template std::vector InitialStack( - const OpTape& op_tape, + const OpTape& op_tape, const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { @@ -396,13 +392,13 @@ std::vector InitialStack( return result; } -template -Status InitialGradients(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, - const TensorTape& tensor_tape, - const OpTape& op_tape, - gtl::FlatMap>* result) { +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -416,11 +412,10 @@ Status InitialGradients(const VSpace& vspace, } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { + if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); + vspace.Ones(op_it->second.output_tensor_info[j])); break; } } @@ -440,6 +435,18 @@ Status InitialGradients(const VSpace& vspace, return Status::OK(); } +// TODO(agarwal): use an automatic mechanism for handling None arguments to +// gradient functions. +// +// Some gradient functions can accept None arguments for gradients. The +// following maps the operation name to the indices at which the corresponding +// gradient function can accept None values. e.g. FusedBatchNorm outputs 5 +// values and hence receives 5 gradient values during backprop. However the +// gradient function uses only the first of those values and ignores the rest. +// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient +// corresponding to index 0 is used, and the gradient values at indices 1-4 are +// ignored (and hence can be None). The backprop algorithm can then leverage +// this by not constructing zeros to pass for those indices. gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { static auto* const m = new gtl::FlatMap>({ {"SoftmaxCrossEntropyWithLogits", {1}}, @@ -457,16 +464,16 @@ gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; -template -Status GradientTape::ComputeGradient( - const VSpace& vspace, +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, gtl::ArraySlice target_tensor_ids, gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); - BackpropInitialState state = PrepareBackprop( + BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); @@ -510,7 +517,7 @@ Status GradientTape::ComputeGradient( out_gradients.reserve(trace.output_tensor_info.size()); bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; + const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = @@ -519,9 +526,7 @@ Status GradientTape::ComputeGradient( func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); + out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); } } else { any_gradient_nonzero = true; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 8486b585c8587e18e8eea18a893fac0a40ff4a27..247236b760dd8c07bbb08426100b6a4d34296d2e 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return result; } -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status) { +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status) { tensorflow::CppShapeInferenceResult::HandleData handle_data; if (!handle_data.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 4bcb5bde62c8a4df4e68c1ce0daaf459434ceb5d..5cce84020bc68d912d259f51512341eb5f464a2c 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); void ExtendSession(TF_Session* session, TF_Status* status); // Returns the serialized CppShapeInferenceResult::HandleData proto for -// `output` if its a resource tensor, or otherwise returns the empty string. -std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // Sets `output` based on `proto`, which should be a serialized -// CppShapeInferenceResult::HandleData proto. +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string // because I couldn't get SWIG to work otherwise. -void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, - const void* proto, size_t proto_len, - TF_Status* status); +void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, + size_t proto_len, TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f56521dac0374849081fe94f16feb08e55647b56..9d2208d84d7d0e96dc5b314f12a250395effdd73 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", + "cc_library_with_android_deps", "tf_cc_binary", + "tf_cc_test", "tf_copts", "tf_gen_op_wrappers_cc", - "cc_library_with_android_deps", + "transitive_hdrs", ) cc_library( @@ -410,6 +411,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", @@ -716,3 +718,26 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +transitive_hdrs( + name = "headers", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":cc_ops", + ":client_session", + ":coordinator", + ":gradient_checker", + ":gradients", + ":ops", + ":queue_runner", + ":remote_fused_graph_ops", + ":scope", + "//tensorflow/cc/profiler", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/cc/saved_model:signature_constants", + "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/cc/tools:freeze_saved_model", + ], +) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a32d1b1eb50fc715084f5ee663a732770db1883c..39593370d1c243e84dc5b6091724d1d404c102b0 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { } } - strings::StrAppend(&class_decl, "\n"); - - if (output_types.empty()) { - strings::StrAppend(&class_decl, " Operation operation;\n"); - } + strings::StrAppend(&class_decl, "\n Operation operation;\n"); for (int i = 0; i < output_types.size(); ++i) { strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i], ";\n"); @@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const { string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); + strings::StrAppend(out, " this->operation = Operation(ret);\n"); + // No outputs. if (graph_op_def.output_arg_size() == 0) { - strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); + strings::StrAppend(out, " return;\n"); return; } if (graph_op_def.output_arg_size() == 1) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 7f6ac4cae78d8d6e118837fce9ae5270336cdc89..6abc9e268e3ac97379954a34017ddffa010db67f 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, refiner_(refiner), scope_used_(nullptr), colocation_constraints_(), - disable_shape_inference_(false) {} + disable_shape_inference_(refiner_ == nullptr) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); @@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_( clear_colocations ? std::unordered_set() : other.impl()->GetColocationConstraints(colocate_with_op)), disable_shape_inference_(other.impl()->disable_shape_inference_) {} +Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice, + const string& assigned_device) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(other.impl()->name_map_), + refiner_(other.impl()->refiner_), + scope_used_(other.impl()->scope_used_), + control_deps_(other.impl()->control_deps_), + name_(other.impl()->name_), + op_name_(other.impl()->op_name_), + exit_on_error_(other.impl()->exit_on_error_), + kernel_label_(other.impl()->kernel_label_), + device_(other.impl()->device_), + assigned_device_(assigned_device), + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} + std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); @@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const { if (!impl()->device_.empty()) { builder->Device(impl()->device_); } + if (!impl()->assigned_device_.empty()) { + builder->AssignedDevice(impl()->assigned_device_); + } } string Scope::Impl::GetUniqueName(const string& prefix, @@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const { return Scope(new Impl(*this, Impl::Tags::Device(), device)); } +Scope Scope::WithAssignedDevice(const string& assigned_device) const { + return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device)); +} + Scope Scope::ColocateWith(const Operation& op) const { return Scope(new Impl(*this, Impl::Tags::Colocate(), op, /* clear_colocations */ false)); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 30c32bd44b0f22d6b29dd3836d431807d0216818..e307d8989b6647dfac8d2691ed2171c86b7f3a7c 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -133,6 +133,10 @@ class Scope { /// the device field set to 'device'. Scope WithDevice(const string& device) const; + /// Returns a new scope. All ops created within the returned scope will have + /// their assigned device set to `assigned_device`. + Scope WithAssignedDevice(const string& assigned_device) const; + /// Return a new scope. All ops created within the returned scope will be /// co-located on the device where op is placed. /// NOTE: This function is intended to be use internal libraries only for diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 58adaef2e942a7fa6b0ce8d5534ac3e2fd380580..514e02e84146b6d95147d83182e5d9a07509cfa1 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -26,6 +26,8 @@ class ShapeRefiner; // graph, status, name_map, and refiner. // This is intended to enable the C API (which are used by other language // bindings) to create a Scope and access C++ functionality (i.e. gradients). +// +// Shape inference is disabled if `refiner` is nullptr. Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); class Scope::Impl { @@ -58,6 +60,7 @@ class Scope::Impl { enum class ExitOnError; enum class KernelLabel; enum class Colocate; + enum class AssignedDevice; }; Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, @@ -74,6 +77,7 @@ class Scope::Impl { Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations); + Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); std::unordered_set GetColocationConstraints( const Operation& colocate_with_op) const; @@ -107,6 +111,7 @@ class Scope::Impl { const bool exit_on_error_ = false; const string kernel_label_ = ""; const string device_ = ""; + const string assigned_device_ = ""; const std::unordered_set colocation_constraints_; // If true, Scope::DoShapeInference() always returns Status:OK(). diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb196189780037f66266484962ba0385e4..2a32a2ed6f7862a29f4ce3d1aba5fdbc86adc670 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2a958f54d50b59f0edaefb27edf0e86..f5a09e09dcda3e06c71d44d5fa5a1b121a9ade58 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = ops::internal::LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7a0932d44d405de0f2edf072f4760126bff36719..10fa33ab5e84dcbc1629bee6214e8969046f19c2 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -25,6 +25,7 @@ test_suite( ":test_graph_tfmatmul_test", ":test_graph_tfmatmulandadd_test", ":test_graph_tfsplits_test", + ":test_graph_tftop_k_test", ":tfcompile_test", ], ) @@ -42,6 +43,7 @@ py_binary( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:session", "//tensorflow/python:training", @@ -66,6 +68,7 @@ genrule( "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", + "test_graph_tftop_k.pb", ], # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # GPUs which might be present. This is important because builds may run @@ -208,6 +211,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tftop_k", + testonly = 1, + config = "test_graph_tftop_k.config.pbtxt", + cpp_class = "TopKComp", + graph = "test_graph_tftop_k.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -226,6 +240,7 @@ tf_cc_test( ":test_graph_tfmatmulandadd", ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", + ":test_graph_tftop_k", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9ec7df163b1425f917e9ec51559efad3e6f05e75..64b861a73091642b03573543a5c55618bf33915d 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib @@ -46,7 +47,7 @@ def tfadd(_): def tfadd_with_ckpt(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -61,7 +62,7 @@ def tfadd_with_ckpt(out_dir): def tfadd_with_ckpt_saver(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -142,6 +143,12 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tftop_k(_): + x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') + output = nn_ops.top_k(x, 2, name='values') + array_ops.identity(output[1], name='indices') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -163,6 +170,7 @@ def main(_): write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tftop_k, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b4ac2d7cbb517be841932b1cfae9e28decdf8d3 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt @@ -0,0 +1,13 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x" } + shape { + dim { size: 5 } + } +} +fetch { + id { node_name: "values" } +} +fetch { + id { node_name: "indices" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 7ac90fb8a9c73bdbc149f263d7d229a6514769f8..f10852c7850f61bfd8b99fa9f1648202d182085e 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, TopK) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + TopKComp fn; + + fn.set_thread_pool(&device); + // x = [4, 1, 4, 4, 3] + fn.arg0(0) = 4; + fn.arg0(1) = 1; + fn.arg0(2) = 4; + fn.arg0(3) = 4; + fn.arg0(4) = 3; + + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + const int32 expected_values[] = {4, 4}; + const int32 expected_indices[] = {0, 2}; + EXPECT_EQ(expected_values[0], fn.result0(0)); + EXPECT_EQ(expected_values[1], fn.result0(1)); + EXPECT_EQ(expected_indices[0], fn.result1(0)); + EXPECT_EQ(expected_indices[1], fn.result1(1)); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 792b7fe14abf91626a0aeb75cdbe319b123ec10c..859c84bb91657422b830255b0217f8946d351458 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -273,6 +273,7 @@ def tf_library( "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 7d5db713f678b696131ff4074d54e3776f019e02..661b444a42eefadf52739d84483e8e26c07fadf5 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -50,7 +51,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -62,7 +63,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda([ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), @@ -76,7 +77,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -94,7 +95,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -111,7 +112,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep @@ -257,6 +258,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -280,7 +282,7 @@ cc_library( deps = [ ":common", ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -322,6 +324,8 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -341,7 +345,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -359,18 +363,20 @@ tf_cc_test( cc_library( name = "compilation_passes", srcs = [ - "build_xla_launch_ops_pass.cc", + "build_xla_ops_pass.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", + "encapsulate_xla_computations_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", ], hdrs = [ - "build_xla_launch_ops_pass.h", + "build_xla_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", + "encapsulate_xla_computations_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -380,12 +386,16 @@ cc_library( ":shape_inference_helpers", ":union_find", ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -397,6 +407,9 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -456,7 +469,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -467,6 +480,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -474,13 +488,16 @@ tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ + "build_xla_ops_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", + "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":node_matchers", ":xla_cluster_util", ":xla_gpu_device", "//tensorflow/cc:cc_ops", @@ -489,8 +506,10 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -500,6 +519,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -518,7 +538,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -593,6 +613,44 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "node_matchers", + testonly = True, + srcs = ["node_matchers.cc"], + hdrs = ["node_matchers.h"], + deps = [ + "//tensorflow/cc:ops", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "node_matchers_test", + srcs = ["node_matchers_test.cc"], + deps = [ + ":node_matchers", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:ops", + "//tensorflow/core:test_main", + ], +) + +tf_custom_op_py_library( + name = "xla_ops_py", + kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], + visibility = [ + ":friends", + ], + deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc deleted file mode 100644 index b17ff589e2597f8d1b5e61f4eaaed7d6ebe6214c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/public/version.h" - -namespace tensorflow { - -static Status BuildLaunchNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, - Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("XlaLaunch"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - AddNodeAttr("Tresults", result_dtypes, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); - - Status status; - *node = graph->AddNode(def, &status); - return status; -} - -static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { - VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; - - int num_constant_args, num_resource_args; - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); - - if (num_constant_args < 0 || num_resource_args < 0 || - num_constant_args + num_resource_args > node->num_inputs()) { - return errors::InvalidArgument( - "Invalid number of constant/resource arguments to XLA kernel."); - } - const int num_nonconst_args = - node->num_inputs() - num_constant_args - num_resource_args; - - DataTypeVector const_dtypes(node->input_types().begin(), - node->input_types().begin() + num_constant_args); - DataTypeVector arg_dtypes( - node->input_types().begin() + num_constant_args, - node->input_types().begin() + num_constant_args + num_nonconst_args); - - // Build a XlaLaunch operator to execute the function body. - Node* launch_node; - TF_RETURN_IF_ERROR(BuildLaunchNode( - graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, - node->output_types(), graph, &launch_node)); - launch_node->set_assigned_device_name(node->assigned_device_name()); - - // Copy incoming edges to the launch node. - for (const Edge* edge : node->in_edges()) { - if (edge->IsControlEdge()) { - graph->AddControlEdge(edge->src(), launch_node); - } else { - graph->AddEdge(edge->src(), edge->src_output(), launch_node, - edge->dst_input()); - } - } - - // Copy outgoing edges to the launch node. - std::vector out_edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : out_edges) { - Node* dst = edge->dst(); - int src_output = edge->src_output(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (edge->IsControlEdge()) { - graph->AddControlEdge(launch_node, dst); - } else { - graph->AddEdge(launch_node, src_output, dst, dst_input); - } - } - graph->RemoveNode(node); - - return Status::OK(); -} - -Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. - if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } - - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); - } - } - - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, - options.flib_def); - } - return Status::OK(); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..5974696b7751d69eb27141173fdab14313925ee9 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "absl/algorithm/container.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { +void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + // TODO(sanjoy): This does not update NodeDef inputs. To be able to update + // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up + // the NodeDef inputs to the function call nodes. + g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); + g->RemoveEdge(edge); + } +} + +struct XlaClusterInfo { + std::vector constant_inputs; + std::vector non_constant_inputs; + std::vector resource_inputs; + NameAttrList function; +}; + +Output IncomingEdgeAsOutput(const Edge* e) { + return Output(e->src(), e->src_output()); +} + +Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { + int num_constant_inputs, num_resource_inputs; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs)); + + if (num_constant_inputs < 0 || num_resource_inputs < 0 || + num_constant_inputs + num_resource_inputs > n->num_inputs()) { + return errors::InvalidArgument( + "Invalid number of constant/resource arguments to XLA kernel."); + } + + int num_non_constant_inputs = + n->num_inputs() - num_constant_inputs - num_resource_inputs; + + std::vector input_edges_vector; + TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector)); + absl::Span input_edges(input_edges_vector); + + absl::c_transform(input_edges.subspan(0, num_constant_inputs), + std::back_inserter(result->constant_inputs), + IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs, num_non_constant_inputs), + std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs + num_non_constant_inputs, + num_resource_inputs), + std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput); + + result->function.set_name(n->type_string()); + *result->function.mutable_attr() = n->def().attr(); + return Status::OK(); +} + +Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { + for (const Edge* e : from->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), to); + } + } + + return Status::OK(); +} + +Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) { + Status status; + Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) + .NewSubScope(n->name()) + .WithDevice(n->requested_device()) + .WithAssignedDevice(n->assigned_device_name()); + + XlaClusterInfo cluster_info; + TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); + + ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), + /*constants=*/cluster_info.constant_inputs, + /*args=*/cluster_info.non_constant_inputs, + /*resources=*/cluster_info.resource_inputs, + cluster_info.function); + TF_RETURN_IF_ERROR( + CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); + + std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + xla_compile.key, n->output_types()); + + MoveOutgoingEdges(g, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + g->RemoveNode(n); + + return Status::OK(); +} +} // namespace + +Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + for (Node* n : graph->op_nodes()) { + // In all cases, only try to compile computational nodes. + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + continue; + } + + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + if (IsXlaCompiledKernel(*n)) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n)); + } + } + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + } + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h similarity index 71% rename from tensorflow/compiler/jit/build_xla_launch_ops_pass.h rename to tensorflow/compiler/jit/build_xla_ops_pass.h index 1dfea93f02081404c5c3c6686a8b28a8530ae8a3..1dd38fa95186dfbe458166caa23a131fbe3c9510 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -13,19 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ -#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -class BuildXlaLaunchOpsPass : public GraphOptimizationPass { +// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and +// executes (using XLA) TF function calls marked with "_XlaCompiledKernel". +class BuildXlaOpsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d56db7b6bc12938b2de9df02b97ff0ca6a42e54 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::NodeWith; +using ::tensorflow::testing::matchers::Op; + +Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { + auto graph = absl::make_unique(OpRegistry::Global()); + TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : graph->nodes()) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = &graph; + BuildXlaOpsPass pass; + TF_RETURN_IF_ERROR(pass.Run(opt_options)); + *result = std::move(graph); + return Status::OK(); +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, int num_constant_args, + int num_resource_args, Node** result) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node); + Status s; + *result = graph->AddNode(call_node, &s); + return s; +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) { + return MakeXlaCompiledKernel(graph, callee_name, node_name, + /*num_constant_args=*/0, /*num_resource_args=*/0, + result); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +TEST(BuildXlaOps, ControlDepsPreserved) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); +} + +TEST(BuildXlaOps, CleanFailureOnBogusAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK( + MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call)); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr graph; + Status failure_status = BuildXlaOps(root, &graph); + ASSERT_FALSE(failure_status.ok()); + EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 56b034a30b7bddb023e54ead22c91a7a18095d2d..6f1ff85f24a4c1fd3e6d54fcff9f8868aee6f750 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 9128b48da3fe9dd3d85d146e16c153c1b3bebf4c..b7ae7fbeb3912882368dc828e8d6fcd50735b04e 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -296,7 +299,7 @@ class SymbolPredicate : public Predicate { template /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector stack; stack.push_back(p); @@ -383,6 +386,8 @@ class PredicateFactory { } Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); + Predicate* MakeInternedAndOr(std::vector simplified_ops, + Predicate::Kind pred_kind); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. This makes checking @@ -417,24 +422,53 @@ class PredicateFactory { } }; - gtl::FlatMap, - HashSignatureForAndOr> + absl::flat_hash_map, + HashSignatureForAndOr> interned_and_or_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_not_instances_; - gtl::FlatMap> + absl::flat_hash_map> interned_and_rec_instances_; - gtl::FlatMap, - HashSignatureForSymbol> + absl::flat_hash_map, + HashSignatureForSymbol> interned_symbol_instances_; }; +Predicate* PredicateFactory::MakeInternedAndOr( + std::vector simplified_ops, Predicate::Kind pred_kind) { + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it != interned_and_or_instances_.end()) { + return it->second.get(); + } + + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + absl::Span operands_slice = simplified_ops; + std::unique_ptr new_pred = + pred_kind == Predicate::Kind::kAnd + ? Make(std::move(simplified_ops)) + : Make(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + interned_and_or_instances_.emplace( + SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred)); + return new_pred_ptr; +} + // Common code to create AndPredicate or OrPredicate instances. Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; - gtl::FlatSet simplified_ops_set; + Predicate::Kind other_pred_kind = + is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; + absl::flat_hash_set simplified_ops_set; std::vector simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -459,7 +493,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet negated_ops; + absl::flat_hash_set negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast(*op).operand()); @@ -472,30 +506,63 @@ Predicate* PredicateFactory::MakeAndOrImpl( } } - std::stable_sort( - simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + // If all ops contain the same subop, then factor it out thanks to the + // distributive property. Such as: + // - (A & B) | (A & C) | (A & D) => A & (B | C | D) + // - (A | B) & (A | C) & (A | D) => A | (B & C & D) + // + // First find any predicates contained in all subops. + std::vector common_inner_operands; + absl::flat_hash_set common_inner_operands_set; + for (Predicate* op : simplified_ops) { + if (op->kind() != other_pred_kind) { + common_inner_operands.clear(); + break; + } - auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); - if (it == interned_and_or_instances_.end()) { - simplified_ops.shrink_to_fit(); - // NB! Because we'll use a non-owning reference to simplified_ops in the - // key for interned_and_or_instances_ we need to be careful to std::move() - // it all the way through. - absl::Span operands_slice = simplified_ops; - std::unique_ptr new_pred = - is_and ? Make(std::move(simplified_ops)) - : Make(std::move(simplified_ops)); + if (common_inner_operands.empty()) { + common_inner_operands.insert(common_inner_operands.end(), + op->GetOperands().begin(), + op->GetOperands().end()); + } else { + std::vector sub_ops_intersection; + common_inner_operands.clear(); + absl::c_copy_if(op->GetOperands(), + std::back_inserter(common_inner_operands), + [&](Predicate* sub_op) { + return common_inner_operands_set.count(sub_op) == 1; + }); + } + if (common_inner_operands.empty()) break; + common_inner_operands_set.clear(); + common_inner_operands_set.insert(common_inner_operands.begin(), + common_inner_operands.end()); + } - Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_or_instances_ - .emplace(SignatureForAndOr(pred_kind, operands_slice), - std::move(new_pred)) - .second); - return new_pred_ptr; - } else { - return it->second.get(); + if (common_inner_operands.empty()) { + return MakeInternedAndOr(std::move(simplified_ops), pred_kind); } + + // For all predicates that can be factored out, remove them and recreate the + // subops. + std::vector factored_ops; + for (Predicate* op : simplified_ops) { + std::vector new_sub_op_ops; + absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops), + [&](Predicate* sub_op) { + return std::find(common_inner_operands.begin(), + common_inner_operands.end(), + sub_op) == common_inner_operands.end(); + }); + factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and)); + } + + Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and); + std::vector outer_ops; + outer_ops.push_back(new_inner_op); + outer_ops.insert(outer_ops.end(), common_inner_operands.begin(), + common_inner_operands.end()); + return MakeAndOrImpl(outer_ops, !is_and); } class DeadnessAnalysisImpl : public DeadnessAnalysis { @@ -507,12 +574,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; - gtl::FlatMap PredicateMapAsString() const; + absl::flat_hash_map PredicateMapAsString() + const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector* result); // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -549,7 +618,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; - gtl::FlatMap predicate_map_; + absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; bool vlog_; }; @@ -558,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -std::vector DeadnessAnalysisImpl::GetIncomingPreds( - Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { - std::vector incoming_preds; +Status DeadnessAnalysisImpl::GetInputPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, + std::vector* result) { + result->clear(); for (const Edge* in_edge : n->in_edges()) { bool should_process = edge_kind == EdgeKind::kDataAndControl || @@ -569,17 +639,27 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()) << n->name(); - incoming_preds.push_back(it->second); + if (it == predicate_map_.end()) { + GraphCycles graph_cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + + // If we didn't return with an error above then the graph is probably + // fine and we have a bug in deadness analysis. + return errors::Internal("Could not find input ", in_edge->DebugString(), + " to ", n->name(), + " when visiting the graph in post-order. Most " + "likely indicates a bug in deadness analysis."); + } + result->push_back(it->second); } } - return incoming_preds; + return Status::OK(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, std::vector* should_revisit) { - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( @@ -608,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -const Edge* FindUniqueBackedge(Node* merge) { +Status CreateMultipleNextIterationInputsError(Node* merge) { + std::vector backedges; + for (const Edge* backedge : merge->in_edges()) { + if (backedge->src()->IsNextIteration()) { + backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); + } + } + return errors::InvalidArgument( + "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), + ": \n", absl::StrJoin(backedges, "\n"), + "\nMerge nodes can have at most one incoming NextIteration edge."); +} + +Status FindUniqueBackedge(Node* merge, const Edge** result) { + *result = nullptr; CHECK(merge->IsMerge()); - const Edge* result = nullptr; for (const Edge* e : merge->in_edges()) { if (e->src()->IsNextIteration()) { - CHECK_EQ(result, nullptr) - << "Multiple backedges to " << merge->DebugString(); - result = e; + if (*result != nullptr) { + return CreateMultipleNextIterationInputsError(merge); + } + *result = e; } } - return result; + return Status::OK(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -697,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, return Status::OK(); } + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); + // We're visiting this merge for the first time and it is a acyclic merge. - Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + Predicate* input_data_pred = + predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -710,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // of an unvisited backedge. Try to pattern match the predicate expression // for that backedge (which should be visited now) into an and recurrence // for the merge node. - if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + const Edge* unique_backedge; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge)); + if (unique_backedge) { if (Predicate* step = DeduceStepPredicate( &predicate_factory_, it->second, predicate_map_[InputEdgeToTensorId(unique_backedge)])) { @@ -741,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. - std::vector input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); SetPredicate(n, {0, Graph::kControlSlot}, @@ -754,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, Status DeadnessAnalysisImpl::HandleGeneric(Node* n, std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. - Predicate* pred = predicate_factory_.MakeAndPredicate( - GetIncomingPreds(n, EdgeKind::kDataAndControl)); + std::vector input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); + Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { SetPredicate(n, output_idx, pred, should_revisit); } @@ -912,9 +1012,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } -gtl::FlatMap +absl::flat_hash_map DeadnessAnalysisImpl::PredicateMapAsString() const { - gtl::FlatMap result; + absl::flat_hash_map result; std::vector tensor_ids; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 3df2679c629ce801fc6c9006415dcd27b40c078e..354782374ad070a3d19ddd68bfb986d5a8285e51 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = gtl::FlatMap; +using PredicateMapTy = absl::flat_hash_map; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // Returns a map describing the predicate each Tensor was mapped to. For diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 28a56044d5e3795fc3ecf5d1092491b87cb90f01..617e31488c7daeb714c0ff7056b786e4eaf7873f 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); } -TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { - // This demonstrates one of the weaknesses in the current approach -- since we - // only do some basic simplifications we can't see that "(A|B)&C" == - // "(A&C)|(B&C)". +TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { + // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "A"); + ops::Switch sw_1 = CreateSwitch(root, "B"); + Output add0 = + ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true); + Output add1 = + ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false); + ops::Merge or2(root.WithOpName("or2"), {add0, add1}); + Output add3 = + ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false); + ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true}); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true"); +} + +TEST(DeadnessAnalysisTest, AndOrDistributive) { + // (A|B)&C == (A&C)|(B&C) Scope root = Scope::NewRootScope().ExitOnError(); ops::Switch sw_0 = CreateSwitch(root, "0"); @@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); } TEST(DeadnessAnalysisTest, Ternary) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index ae7a22f4516fc6c87c0c555214eacac71f2ea0d7..da27f837e88fc3f57f865211929ec9cb1a1af779 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -43,7 +45,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/session_options.h" @@ -58,10 +59,27 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; +void SortControlInputs(GraphDef* gdef) { + int64 num_nodes = gdef->node_size(); + for (int64 i = 0; i < num_nodes; ++i) { + NodeDef* node = gdef->mutable_node(i); + // Stable sort control inputs and leave the order of data inputs unchanged. + std::stable_sort(node->mutable_input()->begin(), + node->mutable_input()->end(), + [](const string& a, const string& b) { + bool a_is_control = absl::StartsWith(a, "^"); + bool b_is_control = absl::StartsWith(b, "^"); + return (!a_is_control && b_is_control) || + (a_is_control && b_is_control && a < b); + }); + } +} + namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. @@ -84,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector>& src_arg_pairs) { - gtl::FlatSet guaranteed_const_nodes; + absl::flat_hash_set guaranteed_const_nodes; std::vector srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { @@ -731,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { graph_->set_versions(graph_in->versions()); } + // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is + // determined. In case of hard placement, ensure all the encapsulated nodes + // have the same requested device, which in turn will be the requested device + // for the entire encapsulated subgraph. In case of soft placement, use a + // deterministic approach to fill in the requested device. Handle co-location + // constraints similarly if they exist. if (device_.empty()) { device_ = node->assigned_device_name().empty() ? node->requested_device() @@ -1340,28 +1364,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { - Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no group_attribute. - attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - } - bool has_group_attr = s.ok(); - s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, - outside_compilation_attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no outside_compilation attribute. - outside_compilation_attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - if (!has_group_attr) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); + AttrSlice attrs = node->attrs(); + attr->clear(); + outside_compilation_attr->clear(); + bool found_group_attribute = false; + bool found_outside_compilation_attribute = false; + for (const auto& node_attr : attrs) { + if (node_attr.first == group_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *attr = node_attr.second.s(); + found_group_attribute = true; + } else if (node_attr.first == outside_compilation_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *outside_compilation_attr = node_attr.second.s(); + found_outside_compilation_attribute = true; } + if (found_group_attribute && found_outside_compilation_attribute) break; + } + + if (found_outside_compilation_attribute && !found_group_attribute) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } else { + return Status::OK(); } - return Status::OK(); } bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 926589546fec72048485d30966f31b24e44b1245..90354a801afb26b003e00c4529069fdc61bbca32 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr; // Name of the attribute containing the number of resource variable arguments. extern const char* const kXlaNumResourceArgsAttr; +// Sorts each node's control inputs by their names. This guarantees that for two +// structually equivalent GraphDefs, we get the same traversal ordering on +// node's control input fields. +// TODO(hpucha): Move the utilities to a more appropriate place. +void SortControlInputs(GraphDef* gdef); + class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ce6fa73fc448ca83fa392aa909cb385453eb8b6 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -0,0 +1,362 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" + +namespace tensorflow { + +const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr = + "_xla_compile_id"; + +namespace { + +const char* const kXlaClusterOutput = "XlaClusterOutput"; + +// Checks if a graph node is marked to be a guaranteed constant. +bool is_guaranteed_constant(const Node& n) { + bool guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) + .ok()) { + return false; + } + return guaranteed_constant; +} + +// Finds the `index` of an _Arg or _Retval node. +Status GetIndexAttr(const Node& n, int num_args, int* index) { + TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); + if (*index < 0 || *index >= num_args) { + return errors::InvalidArgument("Invalid ", n.type_string(), " number ", + *index); + } + return Status::OK(); +} + +// Returns the data type of the destination of an edge. +DataType EdgeType(const Edge* edge) { + return edge->dst()->input_type(edge->dst_input()); +} + +// Adds the control inputs of `node` to `*deps`. +void AddControlInputs(const Node& node, absl::flat_hash_set* deps) { + for (const Edge* edge : node.in_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->src()); + } + } +} + +// Adds the control outputs of `node` to `*deps`. +void AddControlOutputs(const Node& node, absl::flat_hash_set* deps) { + for (const Edge* edge : node.out_edges()) { + if (edge->IsControlEdge()) { + deps->insert(edge->dst()); + } + } +} + +// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts +// the arguments into the order expected by XlaLaunch computations: +// 1) arguments +// 2) resource variable arguments +// See the documentation of EncapsulateSubgraphsInFunctions for the meaning +// of the arguments. +// +// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. +Status RewriteSubgraph(const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + const int num_args = input_permutation->size(); + const int num_retvals = output_permutation->size(); + + std::vector args; + std::vector retvals; + args.reserve(num_args); + retvals.reserve(num_retvals); + for (Node* n : graph->nodes()) { + if (n->type_string() == "_Arg") { + // Check if this is a guaranteed constant. + if (is_guaranteed_constant(*n)) { + return errors::InvalidArgument( + "Guaranteed constants are not supported (", n->name(), ")"); + } + args.push_back(n); + } else if (n->type_string() == "_Retval") { + retvals.push_back(n); + } + } + + if (std::find(args.begin(), args.end(), nullptr) != args.end()) { + return errors::InvalidArgument("Missing or non-consecutive arguments"); + } + + // Reorders the arguments. + std::sort(args.begin(), args.end(), [&](Node* a, Node* b) { + // Non-resources appear before resources + bool a_is_resource = (a->output_type(0) == DT_RESOURCE); + bool b_is_resource = (b->output_type(0) == DT_RESOURCE); + // Uses the name as a tiebreaker so the output is deterministic. + StringPiece a_name(a->name()); + StringPiece b_name(b->name()); + return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name); + }); + + // Sorts the retvals by name so the order is deterministic. + std::sort(retvals.begin(), retvals.end(), + [](Node* a, Node* b) { return a->name() < b->name(); }); + + // Computes the permutation to produce the correct argument order, and update + // the argument indices. + int variable_start_index = num_args; + for (int i = 0; i < num_args; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index)); + if (args[i]->output_type(0) == DT_RESOURCE && + variable_start_index == num_args) { + variable_start_index = i; + } + (*input_permutation)[index] = i; + args[i]->AddAttr("index", i); + } + VLOG(4) << "variable_start_index: " << variable_start_index; + + // Computes the permutation to produce the correct retval order, and update + // the argument indices. + for (int i = 0; i < num_retvals; ++i) { + int index; + TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index)); + (*output_permutation)[index] = i; + retvals[i]->AddAttr("index", i); + } + + AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(), + call_def); + AddNodeAttr("_variable_start_index", variable_start_index, call_def); + + // Uniquify the function name. + GraphDef gdef; + graph->ToGraphDef(&gdef); + + // Before serialization, sort each node's control inputs to achieve + // determinism. Sorting control inputs could help (but not necessarily) create + // a deterministic serialization and fingerprint. Other sources of + // nondeterminism include unstable node ordering. + SortControlInputs(&gdef); + // Fingerprint the function. + // Nondeterminism in serialization would not lead to incorrect results, but + // may cause spurious cache misses. DeterministicSerialization is a + // best-effort deterministic serialization. + string serialized; + TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); + uint64 fingerprint = Fingerprint64(serialized); + LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); + return Status::OK(); +} + +} // namespace + +/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Check for undeclared outputs before Encapsulation, so we can give a better + // error message. + // TODO(phawkins): merge this with the encapsulation code to avoid the extra + // O(n) pass over the edges. + for (const Edge* e : (*graph)->edges()) { + if (!e->IsControlEdge() && + e->src()->attrs().Find(kXlaClusterAttr) != nullptr && + e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && + e->dst()->type_string() != kXlaClusterOutput) { + return errors::InvalidArgument( + "Undeclared output of XLA computation. A common cause of this error " + "is variable initializers that depend on the XLA computation. Edge: ", + e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", + e->dst_input()); + } + } + + auto output = absl::make_unique((*graph)->op_registry()); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, "", **graph, RewriteSubgraph, + /*reuse_existing_functions=*/true, &output, flib_def), + "EncapsulateXlaComputationsPass failed"); + graph->swap(output); + return Status::OK(); +} + +/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( + Graph* graph) { + // Finds all of the XlaLaunch function calls, to avoid mutating the graph + // while iterating. + std::vector launch_nodes; + for (Node* n : graph->nodes()) { + string name; + if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + launch_nodes.push_back(n); + } + } + + // Replaces each launch function call together with its neighboring + // XlaClusterOutput nodes with a XlaLaunch node. + for (Node* launch : launch_nodes) { + int variable_start_index; + TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index", + &variable_start_index)); + + std::vector in_edges; + TF_RETURN_IF_ERROR(launch->input_edges(&in_edges)); + + const int num_inputs = in_edges.size(); + const int num_variables = num_inputs - variable_start_index; + const int num_args = variable_start_index; + + VLOG(4) << "Launch node '" << launch->name() << "'" + << " input edges: " << in_edges.size() << " num_args: " << num_args + << " num_variables: " << num_variables; + + std::vector nodes_to_remove = {launch}; + + // Data and control inputs to the new XlaLaunch node. + std::vector> data_inputs(num_inputs); + absl::flat_hash_set control_inputs; + DataTypeVector arg_types(num_args); + + AddControlInputs(*launch, &control_inputs); + + for (int i = 0; i < num_args; ++i) { + const Edge* edge = in_edges[i]; + data_inputs[i] = {edge->src(), edge->src_output()}; + arg_types[i] = EdgeType(edge); + } + + // Appends the variable inputs. + for (int i = 0; i < num_variables; ++i) { + int pos = variable_start_index + i; + const Edge* edge = in_edges[pos]; + data_inputs[pos] = {edge->src(), edge->src_output()}; + } + + // Outputs. + const int num_outputs = launch->output_types().size(); + absl::flat_hash_set control_outputs; + std::vector>> data_outputs(num_outputs); + DataTypeVector output_types(num_outputs); + + for (const Edge* le : launch->out_edges()) { + if (le->IsControlEdge()) { + control_outputs.insert(le->dst()); + } else { + TF_RET_CHECK(le->src_output() < num_outputs); + Node* output_node = le->dst(); + + TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput) + << le->DebugString(); + nodes_to_remove.push_back(output_node); + + for (const Edge* oe : output_node->out_edges()) { + TF_RET_CHECK(!oe->IsControlEdge()); + data_outputs[le->src_output()].push_back( + {oe->dst(), oe->dst_input()}); + } + output_types[le->src_output()] = output_node->input_type(0); + + AddControlOutputs(*output_node, &control_outputs); + } + } + + NodeDef def; + def.set_name(launch->name()); + + // Target the XLA CPU/GPU backends. + VLOG(2) << "Replacing with XlaLaunch"; + VLOG(2) << "Device is " << launch->requested_device(); + def.set_op("XlaLaunch"); + def.set_device(launch->requested_device()); + AddNodeAttr("Tconstants", DataTypeVector{}, &def); + AddNodeAttr("Targs", arg_types, &def); + AddNodeAttr("Nresources", num_variables, &def); + AddNodeAttr("Tresults", output_types, &def); + NameAttrList function; + function.set_name(launch->type_string()); + AddNodeAttr("function", function, &def); + + for (Node* node : nodes_to_remove) { + VLOG(2) << "Deleting node " << node->DebugString(); + // Ensure that we do not attempt to add control edges to nodes that are + // deleted. + control_inputs.erase(node); + control_outputs.erase(node); + graph->RemoveNode(node); + } + + Status status; + Node* xla_launch = graph->AddNode(def, &status); + if (!status.ok()) { + return status; + } + for (int i = 0; i < data_inputs.size(); ++i) { + graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, + i); + } + for (Node* n : control_inputs) { + graph->AddControlEdge(n, xla_launch); + } + for (int i = 0; i < data_outputs.size(); ++i) { + for (const auto& successor : data_outputs[i]) { + graph->AddEdge(xla_launch, i, successor.first, successor.second); + } + } + for (Node* n : control_outputs) { + graph->AddControlEdge(xla_launch, n); + } + } + return Status::OK(); +} + +Status EncapsulateXlaComputationsPass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "EncapsulateXlaComputations(): " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + VLOG(1) << "EncapsulateXlaComputations() half-way: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", + **options.graph, options.flib_def); + + TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + VLOG(1) << "EncapsulateXlaComputations() finished: " + << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", + **options.graph, options.flib_def); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..99e9dfd598f29697dd009aa32f5317ed3dc647ae --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + + namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + static const char* const kXlaClusterAttr; // _xla_compile_id + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. + static Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. + static Status BuildXlaLaunchOps(Graph* graph); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..22531a4acea3f130175c7cb2e03fcb7570926094 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -0,0 +1,349 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" + +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +static std::unique_ptr MakeOuterGraph( + const FunctionLibraryDefinition& flib_def, const string& function) { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NodeDef def; + TF_CHECK_OK( + NodeDefBuilder("launch0", function, &flib_def) + .Input(a.node()->name(), 0, DT_INT32) + .Input(b.node()->name(), 0, DT_FLOAT) + .Input(c.node()->name(), 0, DT_INT32) + .Input(d.node()->name(), 0, DT_FLOAT) + .Input(u.node()->name(), 0, DT_RESOURCE) + .Input(v.node()->name(), 0, DT_RESOURCE) + .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") + .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") + .Attr("_variable_start_index", 4) + .Finalize(&def)); + + Status status; + Node* launch = scope.graph()->AddNode(def, &status); + TF_CHECK_OK(status); + TF_CHECK_OK(scope.DoShapeInference(launch)); + scope.graph()->AddEdge(a.node(), 0, launch, 0); + scope.graph()->AddEdge(b.node(), 0, launch, 1); + scope.graph()->AddEdge(c.node(), 0, launch, 2); + scope.graph()->AddEdge(d.node(), 0, launch, 3); + scope.graph()->AddEdge(u.node(), 0, launch, 4); + scope.graph()->AddEdge(v.node(), 0, launch, 5); + scope.graph()->AddEdge(w.node(), 0, launch, 6); + + auto out0 = + ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0)); + auto out1 = + ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1)); + auto out2 = + ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2)); + auto out3 = + ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3)); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +// Makes an encapsulate body graph for use in tests. +static std::unique_ptr MakeBodyGraph() { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1); + auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3); + + auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4); + auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5); + auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); + add_attrs(b_identity.node()); + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), arg0, arg2); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, arg3); + add_attrs(g.node()); + + auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"), + b_identity, 0); + auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1); + auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2); + auto out3 = + ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(scope.ToGraph(graph.get())); + return graph; +} + +TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { + // Test that control edge insertion order doesn't affect the cache key + // (cluster name) generated by TPU encapsulate pass. + auto get_serialized_graph = [](bool control_input_reversed, + bool operand_reversed) -> string { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32); + auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32); + + ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1) + : ops::Add(scope.WithOpName("E"), a1, a0); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, + "launch0"); + }; + add_attrs(e.node()); + + TF_CHECK_OK(scope.ToGraph(graph.get())); + auto get_node_in_graph = [&graph](Node* node) { + return graph->FindNodeId(node->id()); + }; + // Insert control edge in different order. The order should not affect + // the encapsulated or serialized graph. + if (!control_input_reversed) { + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + } else { + graph->AddControlEdge(get_node_in_graph(a1.node()), + get_node_in_graph(e.node()), true); + graph->AddControlEdge(get_node_in_graph(a0.node()), + get_node_in_graph(e.node()), true); + } + } + TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + GraphDef gdef; + graph->ToGraphDef(&gdef); + // Before serialization, sort control inputs first to remove + // nondeterminism. + SortControlInputs(&gdef); + string serialized; + SerializeToStringDeterministic(gdef, &serialized); + return serialized; + }; + + // Changing the order of control input shouldn't affect the graph generated. + EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true, + /*operand_reversed=*/false), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); + + // Changing the order of data input should affect the graph generated. + EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/true), + get_serialized_graph(/*control_input_reversed=*/false, + /*operand_reversed=*/false)); +} + +TEST(EncapsulateXlaComputations, Encapsulate) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + std::unique_ptr graph(new Graph(&flib_def)); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + auto add_attrs = [](Node* node) { + node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); + }; + + auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); + add_attrs(b_identity.node()); + + auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT); + add_attrs(read_u.node()); + auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT); + add_attrs(read_v.node()); + auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT); + add_attrs(read_w.node()); + + auto e = ops::Add(scope.WithOpName("E"), a, c); + add_attrs(e.node()); + auto f = ops::Add(scope.WithOpName("F"), read_v, read_w); + add_attrs(f.node()); + auto g = ops::Add(scope.WithOpName("G"), f, d); + add_attrs(g.node()); + + auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity); + auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e); + auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g); + auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u); + + auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0); + auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0); + auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0); + auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1); + auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2); + auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + std::unique_ptr graph_copy(new Graph(&flib_def)); + CopyGraph(*graph, graph_copy.get()); + + TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); + + std::unordered_map index = BuildNodeIndex(*graph); + string function = index.at("launch0")->type_string(); + + // Tests the outer graph is as expected. + { + std::unique_ptr outer = MakeOuterGraph(flib_def, function); + GraphDef expected_def; + outer->ToGraphDef(&expected_def); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def); + } + + // Tests the encapsulated body graph is as expected. + { + std::unique_ptr body = MakeBodyGraph(); + GraphDef expected_body_def; + body->ToGraphDef(&expected_body_def); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, + DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef); + } + + // Encapsulates the same computation again, verifies we reuse the same + // function. Encapsulation should be deterministic to avoid recompilation. + TF_ASSERT_OK( + EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); + std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + string function_copy = index_copy.at("launch0")->type_string(); + EXPECT_EQ(function, function_copy); +} + +TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { + std::unique_ptr body_graph = MakeBodyGraph(); + FunctionDefLibrary flib; + TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function())); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + + std::unique_ptr graph = MakeOuterGraph(flib_def, "launch0"); + TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get())); + + Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32); + auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT); + auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32); + auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT); + auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE); + auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE); + auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE); + + NameAttrList function; + function.set_name("launch0"); + auto launch = ops::XlaLaunch( + scope.WithOpName("launch0").WithDevice("/gpu:0"), + std::initializer_list{}, std::initializer_list{a, b, c, d}, + std::initializer_list{u, v, w}, + DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); + + auto consumer0_a = + ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]); + auto consumer0_b = + ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]); + auto consumer0_c = + ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]); + auto consumer1 = + ops::Identity(scope.WithOpName("consumer1"), launch.results[1]); + auto consumer2 = + ops::Identity(scope.WithOpName("consumer2"), launch.results[2]); + auto consumer3 = + ops::Identity(scope.WithOpName("consumer3"), launch.results[3]); + + GraphDef expected_def; + TF_ASSERT_OK(scope.ToGraphDef(&expected_def)); + + GraphDef actual_def; + graph->ToGraphDef(&actual_def); + TF_EXPECT_GRAPH_EQ(expected_def, actual_def); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 5dcf754969f1709bd0e211b456bc634766239980..085c0e5adbb270e71ff3447a936555c99904e26c 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -23,6 +24,11 @@ namespace tensorflow { // PRE_PLACEMENT passes: +// EncapsulateXlaComputationsPass rewrites computations generated by the +// xla.compile() Python code into XlaLaunch nodes. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, + EncapsulateXlaComputationsPass); + // from // third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc // FunctionalizeControlFlowPass: 27 @@ -32,7 +38,8 @@ namespace tensorflow { // control flow structure (XlaIf/XlaWhile). Following passes must // handle those FunctionDef correctly. -// POST_REWRITE_FOR_EXEC passes: +// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA: + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); @@ -48,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, // Must run after EncapsulateSubgraphsPass. REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, - BuildXlaLaunchOpsPass); + BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 253a5d254792a19d98b75310ea6848f42597c0c7..26cb3af9d69ba1877c67853cde28d2477d394efc 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -7,9 +7,9 @@ package( ) cc_library( - name = "xla_launch_op", - srcs = ["xla_launch_op.cc"], - hdrs = ["xla_launch_op.h"], + name = "xla_ops", + srcs = ["xla_ops.cc"], + hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", @@ -26,6 +26,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc deleted file mode 100644 index b6f2f632f7155234c87a0ea16fdc1910a09ed139..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ /dev/null @@ -1,276 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function) - : OpKernel(ctx), - constants_(constants), - resources_(resources), - device_type_(ctx->device_type()), - function_(function) { - if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = se::host::kHostPlatformId; - } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() - ->platform() - ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) { - use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams(); - platform_id_ = xla_device_metadata_->platform()->id(); - } -} - -Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { - if (xla_device_metadata_) { - *cache = new XlaCompilationCache(xla_device_metadata_->client(), - xla_device_metadata_->jit_device_type()); - return Status::OK(); - } - - auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); - if (!platform.ok()) { - return platform.status(); - } - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - device_type_.type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - -void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOpBase::Compute " - << Canonicalize(function_.name(), AttrSlice(&function_.attr())); - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - - se::Stream* stream = - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - - XlaCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [this, ctx](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, cache); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - - std::map variables = - SnapshotResourceVariables(ctx, resources_); - - xla::LocalClient* client = static_cast(cache->client()); - - XlaAllocator local_xla_allocator(client->backend().platform(), - ctx->device()->GetAllocator({})); - xla::DeviceMemoryAllocator* xla_allocator; - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. - // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which local_xla_allocator above uses) as on an XlaDevice, this is a - // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a - // real allocator to allocate real buffers. - if (xla_device_metadata_) { - xla_allocator = client->backend().memory_allocator(); - } else { - xla_allocator = &local_xla_allocator; - } - - XlaCompiler::Options options; - options.client = client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); - options.device_allocator = xla_allocator; - if (xla_device_metadata_) { - options.shape_representation_fn = - xla_device_metadata_->shape_representation_fn(); - } - - const XlaCompiler::CompilationResult* kernel; - xla::LocalExecutable* executable; - - std::map constant_args; - for (int i : constants_) { - constant_args.insert({i, ctx->input(i)}); - } - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // If we resolve constants we never emit them on the device, meaning that if - // they are needed by a following computation the host has to transfer - // them. Not resolving constants is expected to be faster than resolving - // constants. - compile_options.resolve_compile_time_constants = true; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - - OP_REQUIRES_OK( - ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, compile_options)); - - VLOG(1) << "Executing XLA Computation..."; - - XlaComputationLaunchContext launch_context( - client, xla_allocator, - /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr, - use_multiple_streams_); - launch_context.PopulateInputs(ctx, kernel, variables); - - // Execute the computation. - VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(xla_allocator); - run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(GetXLARandomSeed()); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - - auto run_result = executable->Run(launch_context.arguments(), run_options); - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; - - OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( - ctx, kernel, run_result.ConsumeValueOrDie())); - VLOG(1) << "Done"; -} - -namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - -// Helper static functions to construct parameters for -// XlaLocalLaunchBase constructor from OpKernelConstruction. -std::vector ConstantsVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - std::vector constants(constant_types.size()); - std::iota(constants.begin(), constants.end(), 0); - return constants; -} - -std::vector ResourcesVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Tconstants", &constant_types)); - - DataTypeVector arg_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Targs", &arg_types)); - - int num_resources; - OP_REQUIRES_OK_RETURN(ctx, std::vector(), - ctx->GetAttr("Nresources", &num_resources)); - - std::vector resources(num_resources); - std::iota(resources.begin(), resources.end(), - constant_types.size() + arg_types.size()); - return resources; -} - -NameAttrList FunctionAttr(OpKernelConstruction* ctx) { - const NameAttrList* func; - OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); - return *func; -} - -#undef OP_REQUIRES_OK_RETURN -} // namespace - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), - FunctionAttr(ctx)) {} - -XlaLocalLaunchOp::~XlaLocalLaunchOp() { - VLOG(1) << "XlaLocalLaunchOp destroyed"; -} - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch") - .Device(DEVICE_GPU) - .HostMemory("constants") - .HostMemory("resources"), - XlaLocalLaunchOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h deleted file mode 100644 index e0f10e981737ad60e2b785a235dcb7fe7d21a053..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ - -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. -// The only difference is that it does not require arguments to follow -// the "constants, then regular args, then resources" order. -// It takes vectors of constant and resource arguments explicitly. -// It does not have corresponding OpDef because it is never present -// in the GraphDef. -// Currently, it is used by eager runtime. FunctionLibraryRuntime creates -// this kernel when asked to create a kernel for an XLA-compiled function. -class XlaLocalLaunchBase : public OpKernel { - public: - XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector& constants, - const std::vector& resources, - const NameAttrList& function); - XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; - XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; - ~XlaLocalLaunchBase() override = default; - - void Compute(OpKernelContext* ctx) override; - - protected: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache); - - // Indexes of compile-time constant inputs - std::vector constants_; - // Indexes of resource inputs - std::vector resources_; - - DeviceType device_type_; - NameAttrList function_; - se::Platform::Id platform_id_ = nullptr; - bool use_multiple_streams_ = false; - const XlaDevice::Metadata* xla_device_metadata_ = nullptr; -}; - -// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph -// which will be compiled and executed using XLA. The XlaLocalLaunchOp is -// responsible for handling interactions with the TensorFlow executor. -// Once all inputs are present, and their shapes are known, the op can -// use a 'XlaCompilationCache' to compile and execute code which is specific -// to the shapes of input Tensors. -// XlaLocalLaunchOp uses xla::LocalClient::Compile() and -// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device -// memory. -class XlaLocalLaunchOp : public XlaLocalLaunchBase { - public: - explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); - ~XlaLocalLaunchOp() override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..accc86a86d9d3eca741994ee502bd7580ce49b2e --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -0,0 +1,500 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_ops.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +namespace { + +Status PlatformInfoFromContext(OpKernelConstruction* ctx, + XlaPlatformInfo* result) { + DeviceType device_type = ctx->device_type(); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + std::unique_ptr xla_allocator; + xla::DeviceMemoryAllocator* device_allocator = nullptr; + + if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + platform_id = se::host::kHostPlatformId; + } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { + platform_id = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator above uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. The XlaCompiler needs a real + // allocator to allocate real buffers. + + platform_id = xla_device_metadata->platform()->id(); + device_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + if (!device_allocator) { + TF_ASSIGN_OR_RETURN(se::Platform* const platform, + se::MultiPlatformManager::PlatformWithId(platform_id)); + xla_allocator = absl::make_unique( + platform, ctx->device()->GetAllocator({})); + } + + *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); + + return Status::OK(); +} + +// A closure describing how to run a compiled version of a TensorFlow function. +// +// It may seem unusual to stick the resource variable snapshots in this class. +// This is necessary: we need to use the snapshots observed by the compiler as +// the initial values for the resource variables (and cannot snapshot them again +// during execution) because otherwise we risk observing a different snapshot +// with shapes different from what we compiled for. +class XlaExecutableClosure { + public: + explicit XlaExecutableClosure( + xla::LocalClient* client, xla::LocalExecutable* executable, + const XlaCompiler::CompilationResult* compilation_result, + std::map resource_var_snapshots, + int num_constant_args) + : client_(client), + executable_(executable), + compilation_result_(compilation_result), + resource_var_snapshots_(std::move(resource_var_snapshots)), + num_constant_args_(num_constant_args) {} + + XlaExecutableClosure(XlaExecutableClosure&&) = default; + XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; + + xla::LocalClient* client() const { return client_; } + xla::LocalExecutable* executable() const { return executable_; } + const XlaCompiler::CompilationResult* compilation_result() const { + return compilation_result_; + } + const std::map& resource_var_snapshots() const { + return resource_var_snapshots_; + } + int num_constant_args() const { return num_constant_args_; } + + private: + xla::LocalClient* client_; + xla::LocalExecutable* executable_; + const XlaCompiler::CompilationResult* compilation_result_; + std::map resource_var_snapshots_; + int num_constant_args_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); +}; + +// This maintains a mapping from a globally unique ID to XlaExecutableClosure +// instances. +class XlaExecutableClosureStore { + public: + XlaExecutableClosureStore() : key_counter_(0) {} + + using KeyT = string; + + KeyT Produce(XlaExecutableClosure result) { + mutex_lock l(mutex_); + KeyT key = absl::StrCat(key_counter_++); + bool insert_successful = closures_.emplace(key, std::move(result)).second; + DCHECK(insert_successful); + (void)insert_successful; + return key; + } + + XlaExecutableClosure Consume(const KeyT& key) { + mutex_lock l(mutex_); + auto it = closures_.find(key); + DCHECK(it != closures_.end()); + XlaExecutableClosure value = std::move(it->second); + closures_.erase(it); + return value; + } + + static XlaExecutableClosureStore* Global() { + static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; + return instance; + } + + private: + mutex mutex_; + int64 key_counter_ GUARDED_BY(mutex_); + absl::flat_hash_map closures_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); +}; + +} // namespace + +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + function_(function) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +static Status BuildCompilationCache(OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +static Status CompileToLocalExecutable( + OpKernelContext* ctx, const NameAttrList& function, + const XlaPlatformInfo& platform_info, absl::Span resources, + absl::Span constants, xla::LocalClient** client, + std::map* variables, + const XlaCompiler::CompilationResult** kernel, + xla::LocalExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, platform_info, cache); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + *variables = SnapshotResourceVariables(ctx, resources); + *client = static_cast(cache->client()); + + XlaCompiler::Options options; + options.client = *client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } + options.device_type = cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = platform_info.allocator(); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + + std::map constant_args; + for (int i : constants) { + constant_args.insert({i, ctx->input(i)}); + } + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // If we resolve constants we never emit them on the device, meaning that if + // they are needed by a following computation the host has to transfer + // them. Not resolving constants is expected to be faster than resolving + // constants. + compile_options.resolve_compile_time_constants = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + + return cache->Compile(options, function, constant_args, *variables, ctx, + compile_options, kernel, executable); +} + +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + VLOG(1) << "Executing XLA Computation..."; + + XlaComputationLaunchContext launch_context( + client, platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + platform_info_.UseMultipleStreams()); + launch_context.PopulateInputs(ctx, kernel, variables, + /*missing_ctx_input_prefix=*/0); + + // Execute the computation. + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); + VLOG(1) << "Done"; +} + +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. +std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + +XlaLocalLaunchOp::~XlaLocalLaunchOp() { + VLOG(1) << "XlaLocalLaunchOp destroyed"; +} + +XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + constants_(ConstantsVector(ctx)), + resources_(ResourcesVector(ctx)), + function_(FunctionAttr(ctx)) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaCompileOp::Compute(OpKernelContext* ctx) { + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even + // if it didn't have to compile the cluster because of a compilation-cache + // hit. This is because we at least need new snapshots of the resource + // variables. + XlaExecutableClosureStore::KeyT key = + XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( + client, executable, kernel, std::move(variables), constants_.size())); + + Allocator* cpu_allocator = [&] { + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + return ctx->device()->GetAllocator(host_alloc_attrs); + }(); + + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + compilation_key.flat()(0) = key; + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.flat()(0) = true; + + ctx->set_output(0, compilation_key); + ctx->set_output(1, compilation_successful); +} + +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaRunOp::Compute(OpKernelContext* ctx) { + Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); + + XlaExecutableClosure closure = + XlaExecutableClosureStore::Global()->Consume(key); + + XlaComputationLaunchContext launch_context( + closure.client(), platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); + + // We're missing the must-be-constant inputs, tell `PopulateInputs` + // about this. We don't actually need these inputs because they've + // already been baked into the compiled kernel. + launch_context.PopulateInputs( + ctx, closure.compilation_result(), closure.resource_var_snapshots(), + /*missing_ctx_input_prefix=*/closure.num_constant_args()); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = + closure.executable()->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; + + OP_REQUIRES_OK( + ctx, + launch_context.PopulateOutputs( + ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/closure.num_constant_args())); +} + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); +REGISTER_KERNEL_BUILDER(Name("_XlaCompile") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaCompileOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..489d26eb30a66646158f39ea3fc6f55759c7f88e --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run on. +class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::unique_ptr xla_allocator, + xla::DeviceMemoryAllocator* device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + xla_allocator_(std::move(xla_allocator)), + device_allocator_(device_allocator) { + CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr)); + } + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + xla::DeviceMemoryAllocator* allocator() const { + return device_allocator_ ? device_allocator_ : xla_allocator_.get(); + } + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator and + // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device + // then device_allocator_ is null and xla_allocator_ points to an appropriate + // XlaAllocator instance. + std::unique_ptr xla_allocator_; + xla::DeviceMemoryAllocator* device_allocator_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); +}; + +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + XlaPlatformInfo platform_info_; +}; + +// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaLocalLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'XlaCompilationCache' to compile and execute code which is specific +// to the shapes of input Tensors. +// XlaLocalLaunchOp uses xla::LocalClient::Compile() and +// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device +// memory. +class XlaLocalLaunchOp : public XlaLocalLaunchBase { + public: + explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); + ~XlaLocalLaunchOp() override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); +}; + +class XlaCompileOp : public OpKernel { + public: + explicit XlaCompileOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + NameAttrList function_; + + XlaPlatformInfo platform_info_; +}; + +class XlaRunOp : public OpKernel { + public: + explicit XlaRunOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + XlaPlatformInfo platform_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index e6cc6e52ae537c23d18dc2d3fb94b40a5d23b1a5..4f0c370e65159c89c91ea58733f20f852d9acc99 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) { return elementwise_ops->count(node.op()) > 0; } +// Nodes that XLA can compile are put in `candidates`. Nodes put in +// `isolated_nodes` must either be unclustered or be put in trivial single-node +// clusters. Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - OrderedNodeSet* candidates) { + OrderedNodeSet* candidates, absl::flat_hash_set* isolated_nodes) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -411,6 +414,8 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); + VLOG(4) << "Device type for " << node->name() << ": " + << device_type.type_string(); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { // is_compilable_fn has already logged the reason if it returned false. @@ -439,19 +444,56 @@ Status FindCompilationCandidates( << node->type_string(); continue; } - if (compile_time_const_nodes[node->id()] && - !registration->requires_compilation) { + if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { - // We need to be able to constant fold the nodes in - // compile_time_const_nodes given constant inputs (required by XLA) and - // therefore can't auto-cluster stateful ops since these can never be - // constant folded. - VLOG(2) << "Rejecting " << node->name() - << ": must-be-constant stateful op"; - continue; + // It is easiest to demonstrate the problem we're trying to solve with + // an example. Say we have this graph: + // + // shape = RandomUniformInt(); + // reshape = Reshape(input, shape) + // + // Both RandomUniformInt and Reshape are compilable by XLA so, absent + // any other reason, we will try to put both shape and reshape in the + // same cluster. However, since XLA only supports statically shaped + // values, it will expect to be able to constant fold `shape` to get a + // static shape for `reshape`. This is a problem because side-effecting + // ops like RandomUniformInt() cannot be constant folded. We fix this + // by putting `shape` and `reshape` in different clusters, which results + // in us recompiling `reshape`'s cluster for every new value of `shape`, + // making `reshape` statically sized within each compilation. We + // simplify the solution even further by disallowing operations like + // `shape` from being part of *any* non-trivial cluster. They're either + // not compiled by XLA altogether or, if assigned to an XLA_* device + // with "must compile" semantics, compiled into a trivial single-op + // cluster. This approach leaves some room for improvement, and we can + // consider implementing a more aggressive data-flow-analysis based + // solution in the future if needed. + // + // One ugly problem we have to contend with: certain sets of ops *have* + // to be in the same cluster because values flowing between them have + // types that can't be live-in or live-out of a cluster. These ops are: + // + // - TensorArray ops operating on the same TensorArray instance. + // - Stack ops operating on the same Stack instance. + // + // To work around this we avoid isolating these specific ops. Because + // of this concession it is unsound to auto-cluster them because then + // we'd create clusters we could not compile (because we can't constant + // fold, say, a TensorArrayRead or a StackPopV2). But we don't + // auto-cluster these operations today so we're good for now. + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node->type_string()); + bool is_tensor_array_or_stack_op = + op_info && op_info->resource_kind() != XlaResourceKind::kVariable; + if (!is_tensor_array_or_stack_op) { + VLOG(2) << "Isolating " << node->name() + << ": must-be-constant stateful op"; + isolated_nodes->insert(node); + // Keep going and execute all the other checks. + } } } // We don't auto-cluster functional control flow nodes containing resource @@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; + absl::flat_hash_set isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env : Env::Default(), - is_compilable_fn, &compilation_candidates)); + is_compilable_fn, &compilation_candidates, &isolated_nodes)); if (compilation_candidates.empty()) { VLOG(2) << "No compilable candidates"; @@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl( "Found control flow node in clustering worklist: ", node_from->type_string()); } + + if (isolated_nodes.count(node_from)) { + continue; + } + string from_scope; string to_scope; for (int to : cycles.Successors(from)) { @@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl( node_to->assigned_device_name()) { continue; } + if (isolated_nodes.count(node_to)) { + continue; + } // Look for an _XlaScope on both nodes. If both nodes have a // scope and the scopes do not match, do not cluster along this // edge. This restriction is overridden if the global_jit_level is ON. If @@ -931,6 +982,11 @@ Status MarkForCompilationPass::RunImpl( // Names for each cluster. std::unordered_map cluster_names; + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, + options.flib_def); + } + // Mark clusters for compilation that: // * are placed on a device that requires compilation (an XlaDevice), // * are explicitly marked for compilation (_XlaCompile=true), or diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index c59770a4c8d4a5cb8508a928677f34aeb3d6acf5..2a80c745e3fcebf97bcccb03551feb3d6fb9f831 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" @@ -61,10 +62,10 @@ std::unordered_map GetClusters(const Graph& graph) { return ids; } -gtl::FlatMap> GetClusterSets( +absl::flat_hash_map> GetClusterSets( const Graph& g, std::vector* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - gtl::FlatMap> cluster_sets; + absl::flat_hash_map> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", @@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector expected_clustered_nodes = {"AssignmentW", @@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::vector cluster_names; - gtl::FlatMap> cluster_sets = + absl::flat_hash_map> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 2); @@ -894,5 +895,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { EXPECT_EQ(clusters["fn_call"], ""); } +TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = + ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, + ops::Const(root.WithOpName("test/minval"), 1), + ops::Const(root.WithOpName("test/maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/shape_rng"], ""); + EXPECT_NE(clusters["test/reshape"], ""); + EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); +} + +TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + Scope root = Scope::NewRootScope().ExitOnError(); + ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, + DT_INT32); + Output zero = ops::Const(root.WithOpName("test/zero"), 0); + ops::TensorArrayWrite tensor_array_write( + root.WithOpName("test/write"), tensor_array.handle, zero, + ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); + Output tensor_array_read = + ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, + zero, tensor_array_write.flow_out, DT_INT32); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), + tensor_array_read); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/read"], ""); + EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 65669877f732bad9e145da36a3aedeba611a0fe5..d56d0f8ccfcdab40003be38059228cb255921b64 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -14,18 +14,35 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, SessionOptions* session_options) { - // Assign all nodes to the CPU device. + // Assign all unassigned nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } + // Call AddDevices to register the XLA devices. + // + // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to + // make this more direct, but probably not worth it solely for this test. + std::vector devices; + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); + + auto delete_devices = gtl::MakeCleanup([&] { + for (Device* d : devices) { + delete d; + } + }); + GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8ace628e6b76e011ecddd4d526efc4db9c9237e --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -0,0 +1,458 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +namespace testing { +namespace matchers { +namespace { + +using impl::NodeMatcherProperties; + +string IndentAllButFirstLine(absl::string_view text) { + std::vector lines = absl::StrSplit(text, '\n'); + for (int i = 1; i < lines.size(); i++) { + lines[i].insert(0, " "); + } + return absl::StrJoin(lines, "\n"); +} + +template +bool CompareTensor(const Tensor& actual, const Tensor& expected, + ::testing::MatchResultListener* listener) { + if (actual.NumElements() != expected.NumElements()) { + if (listener->IsInterested()) { + *listener << "\nwas looking for tensor with " << expected.NumElements() + << " elements, found tensor with " << actual.NumElements() + << " elements"; + return false; + } + } + + for (int64 i = 0, e = actual.NumElements(); i < e; i++) { + if (actual.flat()(i) != expected.flat()(i)) { + *listener << "\nmismatch in constant tensor at index " << i + << " expected = " << expected.flat()(i) + << " actual = " << actual.flat()(i); + return false; + } + } + + return true; +} + +bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, + ::testing::MatchResultListener* listener) { + if (tensor.dtype() != expected_tensor.dtype()) { + if (listener->IsInterested()) { + *listener << "\nexpected tensor of type " + << DataType_Name(expected_tensor.dtype()) + << " but found one of type " << DataType_Name(tensor.dtype()); + return false; + } + } + + switch (tensor.dtype()) { + case DT_FLOAT: + return CompareTensor(tensor, expected_tensor, listener); + case DT_DOUBLE: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_INT64: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT8: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT16: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT32: + return CompareTensor(tensor, expected_tensor, listener); + case DT_UINT64: + return CompareTensor(tensor, expected_tensor, listener); + default: + LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. + << DataType_Name(tensor.dtype()); + } +} + +using Input = std::pair; + +struct NodeMatcher : public ::testing::MatcherInterface { + bool MatchAndExplain( + const Node* node, + ::testing::MatchResultListener* listener) const override { + if (op && node->type_string() != *op) { + if (listener->IsInterested()) { + *listener << "\nexpected op " << *op << " but found " + << node->type_string(); + } + return false; + } + + if (assigned_device && node->assigned_device_name() != *assigned_device) { + if (listener->IsInterested()) { + *listener << "\nexpected assigned_device " << *assigned_device + << " but found \"" << node->assigned_device_name() << "\""; + } + return false; + } + + if (name && node->name() != *name) { + if (listener->IsInterested()) { + *listener << "\nexpected name " << *name << " but found " + << node->name(); + } + return false; + } + + if (constant_value) { + const TensorProto* proto = nullptr; + if (!GetNodeAttr(node->def(), "value", &proto).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find \"value\" attribute in node"; + } + return false; + } + + Tensor tensor(proto->dtype()); + if (!tensor.FromProto(*proto)) { + if (listener->IsInterested()) { + *listener << "\ncould not convert TensorProto in \"value\" attribute " + "to Tensor"; + } + return false; + } + + if (!MatchAndExplainTensor(/*tensor=*/tensor, + /*expected_tensor=*/*constant_value, + listener)) { + return false; + } + } + + if (input_matchers) { + if (input_matchers->size() != node->num_inputs()) { + if (listener->IsInterested()) { + *listener << "\nexpected " << input_matchers->size() + << " inputs but node has " << node->num_inputs(); + } + return false; + } + + for (int input_idx = 0, e = input_matchers->size(); input_idx < e; + input_idx++) { + if (!MatchAndExplainInput(node, input_idx, listener)) { + return false; + } + } + } + + std::vector control_deps; + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) { + control_deps.push_back(e->src()); + } + } + + ::testing::StringMatchResultListener inner_listener; + if (control_dep_set && + !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { + if (listener->IsInterested()) { + string explanation = inner_listener.str(); + if (!explanation.empty()) { + explanation = absl::StrCat(", ", explanation, ","); + } + *listener << "ctrl_deps" << explanation << " does not match expected: "; + control_dep_set->DescribeTo(listener->stream()); + } + return false; + } + return true; + } + + void DescribeTo(::std::ostream* os) const override { + std::vector predicates; + + if (name) { + predicates.push_back(absl::StrCat("name: ", *name)); + } + + if (op) { + predicates.push_back(absl::StrCat("op: ", *op)); + } + + if (assigned_device) { + predicates.push_back(absl::StrCat("assigned device: ", *assigned_device)); + } + + bool printed_something = !predicates.empty(); + + *os << absl::StrJoin(predicates, ", "); + + if (constant_value) { + printed_something = true; + *os << "constant value: " << constant_value->DebugString(); + } + + if (input_matchers) { + if (!input_matchers->empty()) { + printed_something = true; + *os << " with " << (input_matchers->size() == 1 ? "only " : "") + << "input" << (input_matchers->size() == 1 ? "" : "s") << " "; + } + + if (input_matchers->size() == 1) { + ::std::stringstream ss; + input_matchers->front().DescribeTo(&ss); + printed_something = true; + *os << "matching " << ss.str(); + } else { + int edge_idx = 0; + for (const ::testing::Matcher& matcher : (*input_matchers)) { + *os << "\n [" << edge_idx << "] matching ("; + ::std::stringstream ss; + matcher.DescribeTo(&ss); + printed_something = true; + *os << IndentAllButFirstLine(ss.str()); + *os << ")"; + edge_idx++; + } + } + } + + if (control_dep_set) { + printed_something = true; + *os << " and control deps "; + control_dep_set->DescribeTo(os); + } + + if (!printed_something) { + *os << "is any node"; + } + } + + bool MatchAndExplainInput(const Node* node, int input_idx, + ::testing::MatchResultListener* listener) const { + const Edge* edge; + if (!node->input_edge(input_idx, &edge).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find incoming edge for input " << input_idx; + } + return false; + } + + ::testing::StringMatchResultListener inner_listener; + Input input = {edge->src(), edge->src_output()}; + if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { + return true; + } + + if (listener->IsInterested()) { + *listener << "\ninput " << input_idx << " does not match expected:\n"; + (*input_matchers)[input_idx].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + + absl::optional op; + absl::optional name; + absl::optional assigned_device; + absl::optional constant_value; + absl::optional>> input_matchers; + absl::optional<::testing::Matcher>> + control_dep_set; +}; + +// Matches a dst and dst_output on an input edge. Today we only use this with +// dst_output=0 but we will eventually need to support multi-output operations. +class InputMatcher : public ::testing::MatcherInterface { + public: + InputMatcher(::testing::Matcher src_matcher, int src_output) + : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} + + bool MatchAndExplain( + Input input, ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener inner_listener; + if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\nsource does not match expected "; + src_matcher_.DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << "\n\t" << explanation; + } + } + return false; + } + if (input.second != src_output_) { + if (listener->IsInterested()) { + *listener << "\nexpected output slot to be " << src_output_ + << " but found " << input.second; + } + return false; + } + + return true; + } + + void DescribeTo(::std::ostream* os) const override { + if (src_output_) { + *os << "output slot: " << src_output_ << ", source: ("; + } + + src_matcher_.DescribeTo(os); + + if (src_output_) { + *os << ")"; + } + } + + private: + ::testing::Matcher src_matcher_; + int src_output_; +}; + +std::vector<::testing::Matcher> NodeMatchersToInputMatchers( + absl::Span> node_matchers) { + std::vector<::testing::Matcher> result; + absl::c_transform(node_matchers, std::back_inserter(result), + [](::testing::Matcher n) { + return ::testing::MakeMatcher(new InputMatcher(n, 0)); + }); + return result; +} +} // namespace + +::testing::Matcher impl::NodeWith( + absl::Span props) { + NodeMatcher* matcher = new NodeMatcher(); + for (const NodeMatcherProperties& prop : props) { + if (prop.name()) { + DCHECK(!matcher->name); + matcher->name = prop.name(); + } + + if (prop.op()) { + DCHECK(!matcher->op); + matcher->op = prop.op(); + } + + if (prop.constant_value()) { + DCHECK(!matcher->constant_value); + matcher->constant_value = prop.constant_value(); + } + + if (prop.assigned_device()) { + DCHECK(!matcher->assigned_device); + matcher->assigned_device = prop.assigned_device(); + } + + if (prop.input_nodes()) { + DCHECK(!matcher->input_matchers); + matcher->input_matchers = + NodeMatchersToInputMatchers(*prop.input_nodes()); + } + + if (prop.control_deps()) { + DCHECK(!matcher->control_dep_set); + matcher->control_dep_set = + ::testing::UnorderedElementsAreArray(*prop.control_deps()); + } + } + + return ::testing::MakeMatcher(matcher); +} + +impl::NodeMatcherProperties Name(string name) { + impl::NodeMatcherProperties props; + props.set_name(std::move(name)); + return props; +} + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op) { + impl::NodeMatcherProperties props; + props.set_op(std::move(op)); + return props; +} + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device) { + impl::NodeMatcherProperties props; + props.set_assigned_device(std::move(assigned_device)); + return props; +} + +impl::NodeMatcherProperties impl::Inputs( + absl::Span> inputs) { + std::vector<::testing::Matcher> inputs_vector; + absl::c_copy(inputs, std::back_inserter(inputs_vector)); + + impl::NodeMatcherProperties props; + props.set_input_nodes(std::move(inputs_vector)); + return props; +} + +impl::NodeMatcherProperties impl::CtrlDeps( + absl::Span> control_deps) { + std::vector<::testing::Matcher> control_deps_vector; + absl::c_copy(control_deps, std::back_inserter(control_deps_vector)); + + impl::NodeMatcherProperties props; + props.set_control_deps(std::move(control_deps_vector)); + return props; +} + +NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val) { + TF_CHECK_OK(val.status); + NodeMatcherProperties props; + props.set_constant_value(val.tensor); + return props; +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val) { + return NodeWith(ConstantValue(val)); +} +} // namespace matchers + +Node* FindNodeByName(Graph* g, absl::string_view name) { + for (Node* n : g->nodes()) { + if (n->name() == name) { + return n; + } + } + + return nullptr; +} +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..0437a7e95c1eb3bdcdbe24a440dd90a5943c0894 --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.h @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a set of matchers for tensorflow nodes. +// +// Example usage: +// +// tensorflow::Node* node = ...; +// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), +// Inputs(NodeWith(Name("input"))))) +// +// Matchable node properties (the expressions that go inside NodeWith(...)) +// are: +// +// - Name(string): matches the node name exactly. We will probably need to +// have this take a string matcher soon in the future. +// +// - Op(string): matches the op exactly. +// +// - AssignedDevice(string): matches the assigned device exactly. +// +// - Inputs(): matches the list of non-control inputs to the node +// exactly (i.e. does not match a suffix or a prefix). +// +// - CtrlDeps(): matches the list of control dependences on the +// node exactly but in any order. +// +// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node +// with the constant value `init`. Implies Op("Const"). +// +// Node properties may not be repeated in a single NodeWith(...) matcher. +// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue +// implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). + +#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ +#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace testing { +namespace matchers { + +namespace impl { + +// ----------------------------------------------------------------------------- +// Implementation details. + +// Properties that we match on for a particular Node. If a particular property +// is nullopt then any value for it is allowed. +class NodeMatcherProperties { + public: + using NodeSeqMatcher = std::vector<::testing::Matcher>; + + const absl::optional& name() const { return name_; } + const absl::optional& op() const { return op_; } + const absl::optional& assigned_device() const { + return assigned_device_; + } + const absl::optional& constant_value() const { + return constant_value_; + } + const absl::optional& input_nodes() const { + return input_nodes_; + } + const absl::optional& control_deps() const { + return control_deps_; + } + + void set_name(string name) { + DCHECK(IsEmpty()); + name_ = std::move(name); + } + + void set_op(string op) { + DCHECK(IsEmpty()); + op_ = std::move(op); + } + + void set_assigned_device(string assigned_device) { + DCHECK(IsEmpty()); + assigned_device_ = std::move(assigned_device); + } + + void set_constant_value(Tensor constant_value) { + DCHECK(IsEmpty()); + constant_value_ = std::move(constant_value); + op_ = "Const"; + } + + void set_input_nodes(NodeSeqMatcher input_nodes) { + DCHECK(IsEmpty()); + input_nodes_ = std::move(input_nodes); + } + + void set_control_deps(NodeSeqMatcher control_deps) { + DCHECK(IsEmpty()); + control_deps_ = std::move(control_deps); + } + + bool IsEmpty() const { + return !name().has_value() && !op().has_value() && + !input_nodes().has_value() && !control_deps().has_value(); + } + + private: + absl::optional name_; + absl::optional op_; + absl::optional assigned_device_; + absl::optional constant_value_; + absl::optional input_nodes_; + absl::optional control_deps_; +}; + +::testing::Matcher NodeWith( + absl::Span props); + +impl::NodeMatcherProperties Inputs( + absl::Span> inputs); + +impl::NodeMatcherProperties CtrlDeps( + absl::Span> control_deps); +} // namespace impl + +// ----------------------------------------------------------------------------- +// Public interface. + +// Matches a node with name `name`. +impl::NodeMatcherProperties Name(string name); + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op); + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device); + +// Matches a node with inputs `inputs`. +// +// `inputs` are ordered; `inputs`[i] must match input i. +template +impl::NodeMatcherProperties Inputs(Ts... inputs) { + return impl::Inputs({inputs...}); +} + +// Matches a node with control dependences `control_deps`. +// +// `control_deps` are unordered and will match the control deps of a node in any +// order. +template +impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { + return impl::CtrlDeps({control_deps...}); +} + +// Matches a constant node with value `val`. +impl::NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val); + +// The main gmock matcher. See file comment for example usage. +template +::testing::Matcher NodeWith(Ts... args) { + std::array array = {args...}; + return impl::NodeWith(array); +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val); +} // namespace matchers + +// If `g` has a node named `name` returns it, otherwise returns null. +Node* FindNodeByName(Graph* g, absl::string_view name); +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..93a8994307b38ac240c22d0a18268638ac7620ae --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" + +namespace tensorflow { +namespace testing { +namespace { + +using ::testing::_; + +using testing::matchers::AssignedDevice; +using testing::matchers::ConstantValue; +using testing::matchers::CtrlDeps; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; + +template +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(NodeMatchers, CheckAgainstConstant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Name("placeholder"), Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Inputs())); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"), Inputs())); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))), + "\nexpected op Add but found Placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), + "\nexpected name add but found placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))), + "\nexpected 1 inputs but node has 0"); +} + +TEST(NodeMatchers, CheckAgainstBinary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); + + EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"), + Inputs(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + + EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), + "\nexpected 0 inputs but node has 2"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))), + "\ninput 0 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_a"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))), + "\ninput 1 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_b"); +} + +TEST(NodeMatchers, CheckControlDependence) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output placeholder_c = + ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT); + Output placeholder_d = + ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT); + + root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node()); + root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node()); + + EXPECT_THAT(placeholder_c.node(), + NodeWith(Name("placeholder_c"), + CtrlDeps(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(placeholder_d.node(), + NodeWith(Name("placeholder_d"), CtrlDeps())); + + EXPECT_EQ( + Explain(placeholder_c.node(), NodeWith(CtrlDeps())), + "ctrl_deps, which has 2 elements, does not match expected: is empty"); + EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))), + "ctrl_deps does not match expected: has 1 element and that element " + "is any node"); +} + +TEST(NodeMatchers, ConstVaulue) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + Output const_0d = ops::Const(root.WithOpName("const_0d"), 42); + + Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}}); + + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42))); + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d"))); + + EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))), + "\nexpected op Const but found Placeholder"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue(43))), + "\nmismatch in constant tensor at index 0 expected = 43 actual = 42"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))), + "\nwas looking for tensor with 4 elements, found tensor with 1 elements"); + EXPECT_EQ( + Explain(const_2d.node(), NodeWith(ConstantValue(42))), + "\nwas looking for tensor with 1 elements, found tensor with 4 elements"); +} + +TEST(NodeMatchers, AssignedDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + + Output assigned_add = + ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b); + assigned_add.node()->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:CPU:0"); + + Output unassigned_add = + ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b); + + EXPECT_THAT( + assigned_add.node(), + NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0"))); + EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice(""))); + + EXPECT_EQ(Explain(unassigned_add.node(), + NodeWith(AssignedDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"))), + "\nexpected assigned_device " + "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); +} + +} // namespace +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 13804c6a0575b921839f99ef7d142e0871693b5a..f72224545b25bc7100e0b6788e6fbf0a7ca63dad 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -4,9 +4,17 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + cc_library( name = "xla_ops", srcs = ["xla_ops.cc"], deps = ["//tensorflow/core:framework"], alwayslink = 1, ) + +tf_gen_op_wrapper_py( + name = "xla_ops_wrapper_py", + out = "xla_ops.py", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index f2473d98ffd5dae55983e601b8d2d65af6a6d54c..bcd1a29b1ff789b5674a21ff66cc6d23a809afc5 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -32,4 +36,58 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); +REGISTER_OP("XlaClusterOutput") + .Input("input: T") + // Note: when replication is supported, this op will have N outputs. + .Output("outputs: T") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an XLA computation to other " + "consumer graph nodes."); + +REGISTER_OP("_XlaCompile") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Input("resources: Nresources * resource") + .Attr("Nresources: int >= 0") + .Output("key: string") + .Output("compilation_successful: bool") + .Attr("function: func") + // The compilation cache is stateful. + .SetIsStateful() + .Doc(R"(XLA Compile Op. For use by the XLA JIT only. + +Compiles a TensorFlow function into an XLA LocalExecutable and returns a key +that _XlaRun can use to look up the LocalExecutable and execute it. + +key: A key that can be used to look up the local executable compiled by the + node and associated metadata. + +compilation_successful: True iff the compilation was successful. Always true +for now. +)"); + +REGISTER_OP("_XlaRun") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Input("key: string") + // XLA random-number generation ops are stateful. + // TODO(phawkins): create stateful and non-stateful variants of _XlaRun. + .SetIsStateful() + .Doc(R"(XLA Run Op. For use by the XLA JIT only. + +Executes a TensorFlow function previously compiled into a LocalExecutable by an +_XlaCompile op. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 10fc9e85d927ffe2416d6d9e6dfd24b286fbf1a0..b1f9e9088f391cb8813d2c82395ffcc0b2081cae 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -15,17 +15,18 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set* result, absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to @@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/NotBackedge); - gtl::FlatSet nodes_to_partially_decluster; + absl::flat_hash_set nodes_to_partially_decluster; TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 35872daa658810707c12fb5020ee6d913167946b..0feb73a89e7050e8c413e5a733da1d87775b0ba3 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -60,9 +60,9 @@ class FakeBinaryOp : public OpKernel { void Compute(OpKernelContext* ctx) override { CHECK(false); } }; -class FakeResourceVarUpdateOp : public OpKernel { +class FakeResourceUpdateOp : public OpKernel { public: - explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + explicit FakeResourceUpdateOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { CHECK(false); } @@ -74,10 +74,9 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary") .HostMemory("host_out"), FakeBinaryOp); -REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") - .Device(DEVICE_CPU) - .HostMemory("something_else"), - FakeResourceVarUpdateOp); +REGISTER_KERNEL_BUILDER( + Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"), + FakeResourceUpdateOp); Status PartiallyDecluster(std::unique_ptr* graph) { FixupSourceAndSinkEdges(graph->get()); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 56e35c0059124015266ffabdf583c8724c8e0908..e039d46ec863920eb7deb5bc20525fdab866415c 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,8 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. class ResourceOpSet { private: - using Impl = gtl::FlatSet; + using Impl = absl::flat_hash_set; public: ResourceOpSet() = default; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3aa9e9c7ed2dd3b7480f40e868c6b07192b68294..0471995015bb080016b523305c90a3e42163a039 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -228,37 +228,38 @@ Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { return CompileImpl(options, function, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, false); + compile_options, /*compile_single_op=*/false, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl(options, name, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, true); + return CompileImpl( + options, name, constant_args, variable_args, ctx, compile_options, + /*compile_single_op=*/true, out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op) { - CHECK_NE(executable, nullptr); + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -357,8 +358,8 @@ Status XlaCompilationCache::CompileImpl( } } TF_RETURN_IF_ERROR(entry->compilation_status); - *compilation_result = &entry->compilation_result; - *executable = entry->executable.get(); + *out_compilation_result = &entry->compilation_result; + *out_executable = entry->executable.get(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 10ad87e38cc4d614e869782329f84351bc3b1f0b..75c7758f730f9f2f8251c02e7fac1a01f8cc9c2b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -68,9 +68,9 @@ class XlaCompilationCache : public ResourceBase { const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -78,9 +78,9 @@ class XlaCompilationCache : public ResourceBase { const XlaCompiler::Options& options, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -89,15 +89,14 @@ class XlaCompilationCache : public ResourceBase { private: // Common implementation of Compile and CompileSingleOp. - Status CompileImpl(const XlaCompiler::Options& options, - const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op); + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompileOptions& compile_options, + bool compile_single_op, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. @@ -152,7 +151,7 @@ class XlaCompilationCache : public ResourceBase { }; mutex compile_cache_mu_; - gtl::FlatMap, Signature::Hash> cache_ + absl::flat_hash_map, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); struct CompileStats { @@ -165,7 +164,7 @@ class XlaCompilationCache : public ResourceBase { mutex compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. - gtl::FlatMap compile_stats_ + absl::flat_hash_map compile_stats_ GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 3ba48e8c318f84a4691fb74434bc009fdd0d81bf..79976c85dff200ce993ebb06e7a20a15b71f6085 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -34,6 +34,7 @@ std::map GetVariables(OpKernelContext* ctx) { OptionalTensor& optional = variables[i]; optional.name = handle.name(); if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); optional.present = true; optional.value = *variable->tensor(); @@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables); + launch_context.PopulateInputs(ctx, result, variables, + /*missing_ctx_input_prefix=*/0); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, TF_RETURN_IF_ERROR(run_result.status()); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( - ctx, result, run_result.ConsumeValueOrDie())); + ctx, result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); return Status::OK(); } @@ -177,7 +180,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, compile_options); + compile_options, result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7e159e3171113b0d53f03bb676ac9c21db7fe77a..003c1d8081a3313fd042cdcaea14508ed1048da3 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" @@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 51797def041d5d223d22fb28408ec91290a1400d..0824c4644e3e5d8e1390b99f12de824bfcdfec24 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -373,7 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - TracingDevice::Compute(op_kernel, context); + op_kernel->Compute(context); } void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, @@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } +void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { + mutex_lock lock(mu_); + sync_on_completion_ = sync_on_completion; +} + +bool XlaDevice::RequiresSyncOnCompletion() const { + mutex_lock lock(mu_); + return sync_on_completion_; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 92891ffa8c6e4a19623172574b17d90fd344c570..0f06b3fc80b7c844dae5643127bdabba8a53b35e 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice { // information for GPU and TPU devices. Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); + // Instructs this XlaDevice to return 'sync_on_completion' for + // RequiresSyncOnCompletion(). + void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + + bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + private: xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) @@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice { static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); - mutex mu_; + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. @@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr thread_pool_; + + // True if the device requires XlaDevice::Sync to be called on completion + // regardless of status. + bool sync_on_completion_ GUARDED_BY(mu_) = false; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 49c85826829fb44d58f10e084f8d757d65bf1882..6967ad1f03fb5dd962d5b41f0c7ab1dfa42fab94 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -65,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); + #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ @@ -89,9 +99,15 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \ + ResourceHandlesOp); \ REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \ + ReadVariablesOp); \ REGISTER_KERNEL_BUILDER( \ Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ DestroyResourceOp); \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ef4466f0056ea98adc1ae6774105466af0d14293..60979556a3245f4a9984cde889835ce31154fe18 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 45745596749207189c60ee1e3dcf19b6ecb7eb5b..19e681af0c940023de2ce82b3b337babe2f3dd5a 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -15,7 +15,7 @@ limitations under the License. // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; } REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, kExecAllTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, + kExecAllTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index affeab4a8c43b63ac0e2b8ef40de5223ce39d410..4f6fc4e068e3ba125ddbca264c1affa1f09f5896 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -42,13 +42,14 @@ using xla::ShapedBuffer; } // anonymous namespace std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables) { + OpKernelContext* ctx, absl::Span variables) { std::map snapshot; for (int i : variables) { Var* variable = nullptr; ResourceHandle handle = HandleFromInput(ctx, i); OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); tensor.present = true; @@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables) { + const std::map& variables, + int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. @@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs( const Tensor* t; for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { int arg_num = kernel->input_mapping[i]; + DCHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = kernel->xla_input_shapes[i]; if (variables.count(arg_num)) { t = &(variables.at(arg_num).value); CHECK(t); } else { - t = &(ctx->input(arg_num)); + t = &(ctx->input(arg_num - missing_ctx_input_prefix)); } if (use_multiple_streams_) { @@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs( Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output) { + ScopedShapedBuffer output, int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -275,6 +278,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " << DataTypeString(type); if (type == DT_RESOURCE) { + TF_RET_CHECK(kernel->outputs[i].input_index >= 0) + << "Invalid input for outputs " << i; ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); @@ -313,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + int actual_input_index = write.input_index - missing_ctx_input_prefix; + if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } @@ -323,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. TF_RETURN_IF_ERROR(LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), &variable, + ctx, HandleFromInput(ctx, actual_input_index), &variable, [&write](Var** ptr) { *ptr = new Var(write.type); return Status::OK(); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 7ac275fab833400b90ced0180192845c9be30534..326d70a027564343408df356833c97e131495da0 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { class XlaAllocator; @@ -43,7 +44,7 @@ class XlaAllocator; // resource variable is not initialized, the corresponding OptionalTensor // will have its `present` field set to false. std::map SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector& variables); + OpKernelContext* ctx, absl::Span variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -88,14 +89,24 @@ class XlaComputationLaunchContext { // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map& variables); + const std::map& variables, + int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. Status PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + xla::ScopedShapedBuffer output, + int missing_ctx_input_prefix); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index e7623582f62486e1b992b0bbafe6862e64f27ae4..ba2401ed2628beeba2be3bf59a067c3d87ca3f9f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -277,9 +277,10 @@ tf_xla_py_test( ], ) +# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "medium", + size = "large", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -893,6 +894,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "tensor_list_ops_test", + size = "small", + srcs = ["tensor_list_ops_test.py"], + # TensorList ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", @@ -977,7 +994,7 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], - tags = ["noasan"], # times out, http://b/78599043 + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1027,6 +1044,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "permute_test", + size = "small", + srcs = ["permute_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn_ops", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", @@ -1104,6 +1134,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1197,8 +1228,21 @@ tf_xla_py_test( ) tf_xla_py_test( - name = "xla_ops_test", + name = "quantized_ops_test", size = "small", + srcs = ["quantized_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "xla_ops_test", + size = "medium", srcs = ["xla_ops_test.py"], disabled_backends = ["cpu_ondemand"], deps = [ diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 4155342787fbbdeaf5c5958c44d007b1ea0660ed..68f52e796c283997b71abcdb9c3bd6aa19cb06fc 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. - for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 17280e445b329d1541aaed78ec106f8f282cbc74..1b39d53dc0908e1fa05f766ca1e601731b26846d 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase): equality_test=self.ListsAreClose) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testBinary( gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), @@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - if dtype not in self.complex_types: # min/max not supported for complex + # min/max not supported for complex + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.maximum, np.array([1, 2], dtype=dtype), @@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[70], [14]], dtype=dtype)) # Complex support for squared_difference is incidental, see b/68205550 - if dtype not in self.complex_types: + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.squared_difference, np.array([1, 2], dtype=dtype), @@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1) + divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( gen_math_ops.floor_div, @@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testDivision(dtype) def testFloatDivision(self): @@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 1, -1, 0], dtype=dtype)) def testIntRemainder(self): - for dtype in self.int_types: + for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) def testFloatRemainder(self): @@ -1437,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 0], dtype=np.int32), expected=np.zeros([4, 0], dtype=dtype)) + x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) + self._testBinary( + array_ops.broadcast_to, + x, + np.array((3, 7, 8, 9), dtype=np.int32), + expected=np.tile(x, (1, 7, 8, 9))) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 7b114d4f85d3a5cadc6af25b55c5a21f90d2a768..1d3979b21bfd915a641fabe1ef40301b3e5a17b4 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -2,90 +2,103 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) def all_backends(): - b = ["cpu"] + plugins.keys() - if cuda_is_configured(): - return b + ["gpu"] - else: - return b + b = ["cpu"] + plugins.keys() + if cuda_is_configured(): + return b + ["gpu"] + else: + return b -def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, - disabled_backends=None, **kwargs): - """Generates py_test targets, one per XLA backend. +def tf_xla_py_test( + name, + srcs = [], + deps = [], + tags = [], + data = [], + main = None, + disabled_backends = None, + **kwargs): + """Generates py_test targets, one per XLA backend. - This rule generates py_test() targets named name_backend, for each backend - in all_backends(). The rule also generates a test suite with named `name` that - tests all backends for the test. + This rule generates py_test() targets named name_backend, for each backend + in all_backends(). The rule also generates a test suite with named `name` that + tests all backends for the test. - For example, the following rule generates test cases foo_test_cpu, - foo_test_gpu, and a test suite name foo_test that tests both. - tf_xla_py_test( - name="foo_test", - srcs="foo_test.py", - deps=[...], - ) + For example, the following rule generates test cases foo_test_cpu, + foo_test_gpu, and a test suite name foo_test that tests both. + tf_xla_py_test( + name="foo_test", + srcs="foo_test.py", + deps=[...], + ) - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - tags: Tags to apply to the generated targets. - data: Data dependencies of the target. - main: Same as py_test's main attribute. - disabled_backends: A list of backends that should not be tested. Supported - values include "cpu" and "gpu". If not specified, defaults to None. - **kwargs: keyword arguments passed onto the generated py_test() rules. - """ - if disabled_backends == None: - disabled_backends = [] + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + tags: Tags to apply to the generated targets. + data: Data dependencies of the target. + main: Same as py_test's main attribute. + disabled_backends: A list of backends that should not be tested. Supported + values include "cpu" and "gpu". If not specified, defaults to None. + **kwargs: keyword arguments passed onto the generated py_test() rules. + """ + if disabled_backends == None: + disabled_backends = [] - enabled_backends = [b for b in all_backends() if b not in disabled_backends] - test_names = [] - for backend in enabled_backends: - test_name = "{}_{}".format(name, backend) - backend_tags = ["tf_xla_{}".format(backend)] - backend_args = [] - backend_deps = [] - backend_data = [] - if backend == "cpu": - backend_args += [ - "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" - ] - elif backend == "gpu": - backend_args += [ - "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" - ] - backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_args += ["--test_device=" + plugins[backend]["device"], - "--types=" + plugins[backend]["types"]] - backend_tags += plugins[backend]["tags"] - backend_args += plugins[backend]["args"] - backend_deps += plugins[backend]["deps"] - backend_data += plugins[backend]["data"] - else: - fail("Unknown backend {}".format(backend)) + enabled_backends = [b for b in all_backends() if b not in disabled_backends] + test_names = [] + for backend in enabled_backends: + test_name = "{}_{}".format(name, backend) + backend_tags = ["tf_xla_{}".format(backend)] + backend_args = [] + backend_deps = [] + backend_data = [] + if backend == "cpu": + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + ] + elif backend == "gpu": + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + ] + backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_args += [ + "--test_device=" + plugins[backend]["device"], + "--types=" + plugins[backend]["types"], + ] + backend_tags += plugins[backend]["tags"] + backend_args += plugins[backend]["args"] + backend_deps += plugins[backend]["deps"] + backend_data += plugins[backend]["data"] + else: + fail("Unknown backend {}".format(backend)) - native.py_test( - name=test_name, - srcs=srcs, - srcs_version="PY2AND3", - args=backend_args, - main="{}.py".format(name) if main == None else main, - data=data + backend_data, - deps=deps + backend_deps, - tags=tags + backend_tags, - **kwargs - ) - test_names.append(test_name) - native.test_suite(name=name, tests=test_names) + native.py_test( + name = test_name, + srcs = srcs, + srcs_version = "PY2AND3", + args = backend_args, + main = "{}.py".format(name) if main == None else main, + data = data + backend_data, + deps = deps + backend_deps, + tags = tags + backend_tags, + **kwargs + ) + test_names.append(test_name) + native.test_suite(name = name, tests = test_names) -def generate_backend_suites(backends=[]): - """Generates per-backend test_suites that run all tests for a backend.""" - if not backends: - backends = all_backends() - for backend in backends: - native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) +def generate_backend_suites(backends = []): + """Generates per-backend test_suites that run all tests for a backend.""" + if not backends: + backends = all_backends() + for backend in backends: + native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend]) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 37e5318bb54c5d8ecdedc7bb346e89765f2adf35..2d225ad226cac368042b95eae8fc29e6fd8e82e0 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase): ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"): array_ops.concat([scalar, scalar, scalar], dim) + # The purpose of this is to ensure that XLA on GPU will not run out of memory + # with too many arguments. + def testConcatLargeNumberOfTensors(self): + with self.cached_session(): + with self.test_scope(): + for concat_dim in range(2): + params = {} + p = [] + shape = np.array([7, 13]) + num_tensors = 1001 + for i in np.arange(num_tensors): + input_shape = shape + placeholder = array_ops.placeholder( + dtypes.float32, shape=input_shape) + p.append(placeholder) + params[placeholder] = np.random.rand(*input_shape).astype( + np.float32) + + concat_inputs = p + c = array_ops.concat(concat_inputs, concat_dim) + result = c.eval(feed_dict=params) + + self.assertEqual(result.shape, c.get_shape()) + cur_offset = 0 + + for i in np.arange(num_tensors): + # The index into the result is the ':' along all dimensions + # except the concat_dim. slice(0, size) is used for ':', and + # a list of slices is used to index into result. + index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)] + index[concat_dim] = slice( + cur_offset, cur_offset + params[p[i]].shape[concat_dim]) + cur_offset += params[p[i]].shape[concat_dim] + self.assertAllEqual(result[index], params[p[i]]) + class ConcatOffsetTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 0af74c2d8f243d8f5ccf1373e0706039cc8ef041..9390870e07d6b5bd90dbc5c04bac0946595dcf7f 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -45,17 +45,21 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def XlaLaunchOpCount(labels): - """Count how many XlaLaunch labels are present.""" - return sum("XlaLaunch(" in x for x in labels) +class DenseLayerTest(test.TestCase): + def countXlaOps(self, labels): + """Count how many XlaCompile/XlaRun labels are present.""" + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + return xla_run_count -class DenseLayerTest(test.TestCase): def testDenseLayerAutoJit(self): """Tests dense layer compilation in auto-jit mode. - Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + Dense layer should be compiled into a single XlaCompile/XlaRun op pair in + auto-jit mode. """ os.environ["TF_XLA_FLAGS"] = ( @@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. Dense layer with static shape input tensor should be compiled into a single - XlaLaunch op by XLA. + XlaCompile/XlaRun op pair by XLA. """ with self.cached_session() as sess: @@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) # No need to check whether ListDiff is compiled or not because ListDiff op # is not used when input tensor shape is fully defined. @@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase): Dense layer uses shape op to get shape of input tensor if its shape is not fully defined. XLA does not cluster shape op with other operators. But in experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO XlaLaunch ops. + cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op + pairs. """ with self.cached_session() as sess: @@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertEqual(2, self.countXlaOps(labels)) self.assertFalse(InLabels(labels, "MatMult")) diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 8c018cccb83a05babb0b7f73b80b4f9de7267c98..374942a0b339b816944ea5529e4f84134b60017b 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn from tensorflow.python.platform import test +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): @@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] @@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): self._testLearning(False, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearningWithGradientChecker(self, data_format): self._testLearning(True, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 089d95daab7e502b4ba13796fadc2ba3f209759b..a38e1edafe883f6d3b64e1d7f94e394cccafa2e9 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase): indices_tf = constant_op.constant(indices) gather_t = array_ops.gather(params, indices_tf) gather_val = session.run(gather_t, feed_dict={params: params_np}) - np_val = params_np[indices] + np_val = constant_op.constant(params_np[indices]) self.assertAllEqual(np_val, gather_val) def testScalar2D(self): @@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant(2) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, 2, axis=axis) + expected = constant_op.constant( + np.take(params_np, 2, axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): @@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant([0, 1, 0, 2]) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32_Int64Indices(self): @@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase): params: params_np, indices: indices_np }) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testHigherRank(self): @@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase): tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, axis=axis) + gather_np = constant_op.constant( + np.take(params, indices, axis=axis), dtype) self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 6fe5a66e0e6717ec738dded9196eef6ba1e2114d..68fdb5caf4c2a496b5058cdda40ca650484a6e0e 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -605,10 +605,6 @@ class ResizeBilinearTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): def testNMS128From1024(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -644,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): # Three boxes are selected based on IOU. boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -693,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): # Three boxes are selected based on IOU. # One is filtered out by score threshold. - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] @@ -736,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + def testNMS3Then1WithScoreMaxThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + # One is filtered out by max_output_size. + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 0839fb123e83960e198eac2bed769afbdd517889..de68ff0e32cd59e65094c0b7319f8ab213eed4db 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -77,11 +77,11 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" +def MetadataHasXlaOp(run_metadata): + """Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaRun") class JitLaunchTest(test.TestCase): @@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node - # actually ran. However, it is sometimes possible for XlaLaunch ops to be - # constant-folded away, so the check is optional. + # + # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun + # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun + # ops to be constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] @@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase): print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase): y = math_ops.add(x, x) return y, y - # Exercises compling a function (say, Foo) which calls another - # function (say, Bar) which is not inlined. When the compiler compiles - # Foo, it needs to symbolic execute Bar correctly regardless whether - # Bar is inlined or not. + # Exercises compiling a function (say, Foo) which calls another function + # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs + # to symbolically execute Bar correctly regardless of whether Bar is inlined + # or not. # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. @@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase): # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): @@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): @@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): @@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaCompile")) + self.assertFalse(InLabels(labels, "XlaRun")) - # Compile the backprop. One XlaLaunch. + # Compile the backprop. One XlaCompile/XlaRun pair. labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaCompile")) + self.assertTrue(InLabels(labels, "XlaRun")) class ElementWiseFusionTest(test.TestCase): @@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("XlaLaunch(" in x for x in labels) - return output, count + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + + return output, xla_run_count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py index 43c469d0320645cdad6ddc67f3e8cb1374b8e9e5..73b3638e801e7389e83953f6662bcfc78ad86203 100644 --- a/tensorflow/compiler/tests/lstm.py +++ b/tensorflow/compiler/tests/lstm.py @@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq): def RandomVar(shape, name=None): """Returns a variable of the given shape initialized to random values.""" - return variables.Variable( + return variables.VariableV1( random_ops.random_uniform(shape), dtype=dtypes.float32, name=name) diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index f985c5d2d96e06fc0117f3935d61b19c9e8562b1..38cb2f83efc48ffcdf5403a23e666963b2ea4da1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase): output.run() def testConstants(self): - constants = [ - np.float32(42), - np.array([], dtype=np.float32), - np.array([1, 2], dtype=np.float32), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), - np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], - dtype=np.float32), - np.array([[[]], [[]]], dtype=np.float32), - np.array([[[[1]]]], dtype=np.float32), - ] - for c in constants: - self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + for dtype in self.numeric_types: + constants = [ + dtype(42), + np.array([], dtype=dtype), + np.array([1, 2], dtype=dtype), + np.array([7, 7, 7, 7, 7], dtype=dtype), + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + def testComplexConstants(self): + for dtype in self.complex_types: + constants = [ + dtype(42 + 3j), + np.array([], dtype=dtype), + np.ones([50], dtype=dtype) * (3 + 4j), + np.array([1j, 2 + 1j], dtype=dtype), + np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4 + 6j], [5, 6]], + [[10 + 7j, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1 + 3j]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb9274df4f579fbc6076bf55c9307e4d1cb7768 --- /dev/null +++ b/tensorflow/compiler/tests/permute_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the DataFormatVecPermute operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaPermuteOpTest(xla_test.XLATestCase): + + def _runPermuteAndCompare(self, x, src_format, dst_format, expected): + with self.cached_session() as session: + with self.test_scope(): + placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) + param = {placeholder: x} + output = nn_ops.data_format_vec_permute( + placeholder, src_format=src_format, dst_format=dst_format) + result = session.run(output, param) + self.assertAllEqual(result, expected) + + def testNHWCToNCHW(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", + [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", + [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", + [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", + [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..80c338513bc9ff6b8e56c5ad6b904af9e06a3715 --- /dev/null +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for quantized operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class QuantizedOpsTest(xla_test.XLATestCase): + + # Verify that quantized types can be clustered by XLA. + def testQuantizedTypeRoundtrip(self): + with self.cached_session() as session: + for dtype in self.quantized_tf_types: + in_values = np.array([1, 2, 3, 4, 5, 6]) + expected = [[1, 2], [3, 4], [5, 6]] + with self.test_scope(): + p = array_ops.placeholder(dtype=dtypes.int32) + x = math_ops.cast(p, dtype) + x = array_ops.reshape(x, [3, 2]) + + value = session.run(x, {p: in_values}) + self.assertAllEqual(value, expected) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 6e183441179ebf2e8c063b333f9328d6fa86cc88..36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): - return set(self.numeric_types) - set(self.complex_types) + return set(self.numeric_types) - set( + self.complex_types) - {np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. @@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): for dtype in self._random_types(): @@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - self._testRngIsNotConstant(rng, dtypes.float32) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testTruncatedNormalIsInRange(self): count = 10000000 - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + # TODO(b/34339814): make this test work with 16 bit float types. + for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) @@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): - # TODO(b/26783907): this test requires the CPU backend to implement sort. - if self.device in ["XLA_CPU"]: - return with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index bddda6f30245d4b8281a77783ec9922d61bd3883..dc119fb0f8a41a3772a8c9508bf2db657f57de88 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" @@ -63,7 +64,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) -> float { float generated; @@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_DOUBLE: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); test::FillFn(&tensor, [&](int i) -> double { double generated; @@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_COMPLEX64: { - gtl::FlatSet> already_generated; + absl::flat_hash_set> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); test::FillFn(&tensor, [&](int i) { complex64 generated; @@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT32: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); test::FillFn(&tensor, [&](int i) -> int32 { int32 generated; @@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT64: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); test::FillFn(&tensor, [&](int i) -> int64 { @@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_BOOL: { - gtl::FlatSet already_generated; + absl::flat_hash_set already_generated; std::bernoulli_distribution distribution; test::FillFn(&tensor, [&](int i) -> bool { bool generated; @@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) { do { dims = RandomDims(1); size = TensorShape(dims).num_elements(); - } while (size * size < tf_xla_max_tensor_size); + } while (size * size > tf_xla_max_tensor_size); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 60c2337743b44e9bad61c4d65280eb2b1a1ad9ea..abc822ef363e5d83c99bb963582662ccfce4cd6d 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): def testSeqLength(self): for dtype in self.all_types: - for seq_dtype in self.int_types: + for seq_dtype in self.all_types & {np.int32, np.int64}: self._testBasic(dtype, seq_dtype) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 51c04b5c4796474700a92a8b23a1cbdf533fcbb4..57f0ab7a9eae16ab3de61af9760dfba1ab355b46 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,22 +48,30 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) - def testTopK(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return + def testKeyValueSort(self): + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for key_type in supported_types.intersection(self.numeric_types): + for value_type in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -89,10 +97,6 @@ class XlaSortOpTest(xla_test.XLATestCase): expected=[x[indices].astype(dtype), indices]) def testTopK2D(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -122,10 +126,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: @@ -144,10 +144,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 1bea7d9355e40c5a71f848dabc0fa7fa760429d2..e8741bc468585ff9fb049dcd87700f8048d74026 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): - return [dtypes.float32] + return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) @@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - x = stateless.stateless_random_uniform( + x = stateless.stateless_random_normal( shape=[10000], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) @@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 @@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c079d595c440cac644f5461154509abe7b1d1ed --- /dev/null +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ops which manipulate lists of tensors via bridge.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +def scalar_shape(): + return ops.convert_to_tensor([], dtype=dtypes.int32) + + +class ListOpsTest(xla_test.XLATestCase): + + def testElementShape(self): + with self.cached_session() as sess, self.test_scope(): + dim = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(dim, 15), num_elements=20, + element_dtype=dtypes.float32) + e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) + e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) + self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) + self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) + + def testPushPop(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + + def testPushPopSeparateLists(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=scalar_shape(), + num_elements=num, + element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) + _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) + + def testEmptyTensorList(self): + dim = 7 + with self.cached_session() as sess, self.test_scope(): + p = array_ops.placeholder(dtypes.int32) + l = list_ops.empty_tensor_list( + element_shape=(p, 15), element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(dim, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Use TensorListReserve instead"): + self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 55a992195f2df72677b77757ae86171fa662439f..98a07709c611178effd7794ba58ba89770c6d77f 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase): expected=np.array([[2], [5]], dtype=dtype)) def testClipByValue(self): - # TODO(b/78258593): enable integer types here too. - for dtype in self.float_types: + for dtype in self.numeric_types - self.complex_types: test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 5b0e57f83ff4b5a8d1891bef0675074bd67addce..77f6eee0cf8ddc9b76f150e1038bf66da34c5218 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array( @@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase): def testFloatOps(self): for dtype in self.float_types: - # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. - if dtype == np.float16 and self.device == "XLA_CPU": - continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) @@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) def testNumericOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 1e600c44e9af66994686359eb0e1a1e52bea93fd..4cf88fc523735cc2d22e085afb83790c7ebb48e4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dtype=dtype)) def testNeg(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.uint8, np.int8}: self._assertOpOutputMatchesExpected( xla.neg, args=(np.array([1, 2, 3], dtype=dtype),), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 88827cb53bee7bb809d0163d6badcef17e59aa78..98a41981cf30917bc2054c19af5d8176bdfc9862 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -97,10 +97,23 @@ class XLATestCase(test.TestCase): ]) self._numeric_tf_types = set( self.int_tf_types | self._float_tf_types | self.complex_tf_types) - - self._all_types = set( - [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.quantized_tf_types = set( + dtype for dtype in self._all_tf_types if dtype.is_quantized) + + # Quantized types don't have a numpy equivalent, include them in + # all_tf_types but not in all_types. + # TODO(b/115960798): Parametrize tests on TF types instead of numpy types + # and remove all_types. + self._all_types = set(dtype.as_numpy_dtype + for dtype in self._all_tf_types + if not dtype.is_quantized) self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self.signed_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if not dtype.is_unsigned) + self.unsigned_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if dtype.is_unsigned) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index d549e7bb59905160a5599fea83667951a60e674d..3f631f91ec442c149b3ea4df3826d98b0419a76f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -611,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -634,6 +635,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -648,6 +650,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14939d4f4fba598318200f71c2eb0270..adcdb6c8f762cb7ea68485167bf7fc8ccb343a51 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -30,14 +30,15 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", - out_ops_file = "ops/xla_jit_op", + include_internal_ops = 1, + out_ops_file = "ops/xla_jit_ops", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( name = "xla_jit_ops", - srcs = ["ops/xla_jit_op.cc"], - hdrs = ["ops/xla_jit_op.h"], + srcs = ["ops/xla_jit_ops.cc"], + hdrs = ["ops/xla_jit_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 922ae7c79a1d3e0ad55bc2858a45cd6be1dc1117..027ca6d2d2f616177d91d9d57d1ff373bab2a754 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_arg_indices, std::vector* compile_time_const_nodes, std::function edge_filter) { - // Operators that don't look at the data of their inputs, just the shapes. - const std::unordered_set metadata_ops = { - "Rank", - "Shape", - "ShapeN", - "Size", - }; - std::vector compile_time_const_nodes_impl; if (compile_time_const_nodes) { CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); @@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g, if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. - if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return; + if (XlaOpRegistry::IsMetadataOp(node->type_string())) { + return; + } // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f792c520329039c8da63d07ea27fa1c403f5c67d..0362682bd6a8d0977bb09854ef448075fba99273 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -77,7 +79,10 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map* canonicalized_name_to_new_name) { + std::map>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -89,7 +94,20 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); - const FunctionDef& fdef = body->fdef; + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -97,7 +115,7 @@ Status FunctionalizeControlFlowForFunction( // it. std::vector>> nodes_to_associated_functions; - for (auto* n : body->graph->nodes()) { + for (auto* n : g->nodes()) { auto associated_functions = GetAssociatedFunctions(*n, flr); if (!associated_functions.empty()) { nodes_to_associated_functions.push_back({n, associated_functions}); @@ -108,57 +126,86 @@ Status FunctionalizeControlFlowForFunction( auto associated_functions = iter.second; for (auto& associated_function : associated_functions) { string name = associated_function.func_name(); - string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + string canonicalized_name = + Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; + bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already functionalized this function, skip functionalization - // but still rewrite the node. - new_name = iter->second; + // If we already processed this function, check if it was rewritten. If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. + function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + if (associated_function.type() == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { + // For SymbolicGradient, `name` is always "SymbolicGradient", + // which is not very informative. Use node name instead. + new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + name, new_name, associated_function.attrs(), fld, flr, + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add an non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. + (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + g, n, fld, associated_function, new_name)); } - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - body->graph, n, fld, associated_function, new_name)); } } - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *body->graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *body->graph, fld); + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); - - // Copy signature and ret from original FunctionDef. - *functionalized_fdef.mutable_signature() = fdef.signature(); - *functionalized_fdef.mutable_ret() = fdef.ret(); - functionalized_fdef.mutable_signature()->set_name(new_func_name); - - // Add rewritten FunctionDef into library. - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; + + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } } return ret_status; @@ -184,7 +231,7 @@ Status FunctionalizeControlFlowPass::Run( {"TPUCompile", "function"}, {"XlaLaunch", "function"}, }; - std::map canonicalized_name_to_new_name; + std::map> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -199,12 +246,15 @@ Status FunctionalizeControlFlowPass::Run( << ". Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); + bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } } } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index ab7cac7100d39377828462f0dee5df98a7319cc3..e9f02201cf6bed5495dff7dff76c5bafe7771516 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -55,17 +55,17 @@ namespace tensorflow { // op registration infrastructure instead of FunctionLibraryRuntime. class GraphCompiler { public: - GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, - Graph* graph, FunctionLibraryRuntime* flib, + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, ScopedStepContainer* step_container) - : xla_context_(xla_context), - device_(device), + : device_(device), graph_(graph), flib_(flib), step_container_(step_container) {} - // Compiles the graph. The results are written in `xla_context` that is passed - // into the compiler. + // Compiles the graph. The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. Status Compile(); private: @@ -82,7 +82,6 @@ class GraphCompiler { // using `compiler_`. Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); - XlaContext* xla_context_; XlaCompilationDevice* device_; Graph* graph_; FunctionLibraryRuntime* flib_; diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 46794f7b5070a1a64ac8e16e6a066156a4fa693b..224e5ea123b4905bcfe0947722dbaf4a703f9893 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -62,6 +62,7 @@ tf_kernel_library( "one_hot_op.cc", "pack_op.cc", "pad_op.cc", + "permute_op.cc", "pooling_ops.cc", "qr_op.cc", "quantize_and_dequantize_op.cc", @@ -94,6 +95,7 @@ tf_kernel_library( "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", + "tensor_list_ops.cc", "tile_ops.cc", "topk_op.cc", "training_ops.cc", @@ -113,11 +115,13 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":conv_op_helpers", ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", @@ -156,6 +160,7 @@ tf_kernel_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", @@ -172,6 +177,27 @@ tf_kernel_library( ], ) +cc_library( + name = "conv_op_helpers", + srcs = ["conv_op_helpers.cc"], + hdrs = ["conv_op_helpers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/types:span", + ], +) + tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index b3ad0aea84eef601de08909f760699b8700d28f4..a267c0c72fce67d7c22c55a57f8d5ac4ffd2b7e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index df17da4c1ca07053cf63757f1acf2b1a3735e705..47e517a6576d3a848bc41ceb703df2bd778c4a35 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); @@ -65,7 +84,10 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // } static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + if (DataTypeIsUnsigned(dtype)) { + return xla::Div(x, y); + } auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); @@ -81,12 +103,30 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); +} +XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Div(x, y)); +} +XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 4bd7c74dca2a7cbb51f2a329ac575d635f314516..9bb11fb67e3e4ddc48d68631c60f96c60b921094 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,60 +32,9 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. - absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector broadcast_dims; - std::vector broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { - broadcast_shape.push_back(output_dims[i]); - } - if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. - broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index f4106051043859a6786705009d76b02a64cd3ff1..0ae23aa6dfe49048ac5cb8ae00c12432b2e2a2fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -37,6 +37,16 @@ limitations under the License. namespace tensorflow { namespace { +// Used to determine the number of Tensors allowed in a Concat op to prevent +// going over the max gpu parameter memory size. This is an issue because concat +// is variadic and can have an unlimited number of arguments when called. +// Concat ops with more Tensors than this will be split into multiple concat +// ops. +// +// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass +// along with boxing large numbers of parameters. +constexpr int64 kMaxConcatArgsPerOp = 500; + // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector input_data; + std::vector partial_concats; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { @@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; + + // Concat is associative, so it can be split into many operations when too + // many arguments are in a single op. This is a temporary workaround for + // b/112613927 where too many parameters in an XlaLaunchOp later result in + // too many parameters to a single GPU kernel. + if (i && i % kMaxConcatArgsPerOp == 0) { + partial_concats.push_back( + xla::ConcatInDim(ctx->builder(), input_data, axis)); + input_data.clear(); + } } + // Add any inputs that have not been put into another concat yet. + partial_concats.insert(partial_concats.end(), input_data.begin(), + input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); + // Don't add an additional "identity" concatenate for better readibility of + // IR. + if (partial_concats.size() == 1) { + ctx->SetOutput(0, partial_concats.front()); + } else { + ctx->SetOutput(0, + xla::ConcatInDim(ctx->builder(), partial_concats, axis)); + } } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index da8cf3fc6fa694f592280f8c249d317827d9cd09..2628ef8e2454976aeff3859fa5dc1d8e106f32e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX64: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex64(proto_.scomplex_val(0), + proto_.scomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9a1be494066e4f935a1d818bc86c86333e34fae --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -0,0 +1,509 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. + +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Returns the expanded size of a filter used for depthwise convolution. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { + int num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape expanded_shape = shape; + expanded_shape.set_dimensions( + num_dims - 1, + shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1)); + return expanded_shape; +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, + xla::XlaBuilder* builder) { + xla::Shape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = + filter_shape.dimensions(filter_shape.dimensions_size() - 1); + int64 input_feature = + filter_shape.dimensions(filter_shape.dimensions_size() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + std::vector expanded_feature_broadcast_dims( + expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dimensions_size() - 2}); +} + +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dimensions_size() - 2; + int64 output_feature_dim = filter_shape.dimensions_size() - 1; + int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); + int64 input_feature = filter_shape.dimensions(input_feature_dim); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + xla::Shape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dimensions( + output_feature_dim, depthwise_multiplier * input_feature); + return xla::Reshape( + filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions())); +} + +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. +xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { + auto masked_expanded_filter = + xla::Select(CreateExpandedFilterMask(filter_shape, builder), + filter_backprop, xla::ZerosLike(filter_backprop)); + + auto elem_type = filter_shape.element_type(); + return xla::Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the select above guarantees + // that only one element is non zero, so there cannot be accumulated + // precision error. + xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type), + CreateScalarAddComputation(elem_type, builder), + {filter_shape.dimensions_size() - 2}), + xla::AsInt64Slice(filter_shape.dimensions())); +} + +// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA +// convolutions (as currently implemented). +Status CheckConvAttrs(const ConvOpAttrs& attrs) { + const int num_dims = attrs.num_spatial_dims + 2; + if (attrs.strides.size() != num_dims) { + return errors::InvalidArgument("Sliding window strides field must specify ", + num_dims, " dimensions"); + } + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not yet support strides in the batch and " + "depth dimensions."); + } + if (attrs.dilations.size() != num_dims) { + return errors::InvalidArgument("Dilations field must specify ", num_dims, + " dimensions"); + } + if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not support dilations in the batch and " + "depth dimensions."); + } + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + if (attrs.dilations[input_dim] < 1) { + return errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + attrs.dilations[input_dim]); + } + } + return Status::OK(); +} + +// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes +// to TensorShapes. +Status ConvBackpropComputeDimensionsV2XlaShapes( + StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, + const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, + absl::Span dilations, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + TensorShape input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); + return ConvBackpropComputeDimensionsV2( + label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape, dilations, strides, padding, data_format, + dims); +} + +} // anonymous namespace + +xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx) { + ConvOpAttrs attrs; + attrs.num_spatial_dims = num_spatial_dims; + attrs.depthwise = depthwise; + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + + string data_format; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); + if (!FormatFromString(data_format, &attrs.data_format)) { + return errors::InvalidArgument("Invalid data format: ", data_format); + } + + return attrs; +} + +xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = conv_input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input)); + // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + + // For 2D convolution, there should be 4 dimensions. + int num_dims = attrs.num_spatial_dims + 2; + if (input_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString()); + } + if (filter_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument( + "filter must be ", num_dims, + "-dimensional: ", filter_shape.DebugString()); + } + + // The last two dimensions of the filter are the input and output shapes. + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. + if (in_depth != input_shape.dimensions(feature_dim)) { + return errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, " vs ", + input_shape.dimensions(feature_dim)); + } + + if (attrs.depthwise) { + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); + } + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims, 1); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims); + dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = attrs.strides.at(dim); + rhs_dilation[i] = attrs.dilations.at(dim); + + int64 unused_output_size; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + input_shape.dimensions(dim), filter_shape.dimensions(i), + rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + return xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1); +} + +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + int num_dims = attrs.num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + auto* builder = filter.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(out_backprop)); + + xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, + out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, + attrs.data_format, &dims)); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. + dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims); + + std::vector kernel_spatial_dims(attrs.num_spatial_dims); + std::vector> padding(attrs.num_spatial_dims); + std::vector lhs_dilation(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = attrs.dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + return xla::ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/ones, padding, + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) / + filter_shape.dimensions(attrs.num_spatial_dims + 1) + : 1); +} + +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = activations.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape activations_shape, + builder->GetShape(activations)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(gradients)); + const xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // The last two dimensions of the filter are the input and output shapes. + int num_dims = attrs.num_spatial_dims + 2; + int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. + dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64 pad_before = + attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; + } + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); + + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } + + return filter_backprop; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..6e1b70a47850ae5c05939f8dfb7ec129c031df21 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static xla::StatusOr Create(int num_spatial_dims, bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector dilations; + std::vector strides; + Padding padding; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. +xla::StatusOr MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs); +xla::StatusOr MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 674720e22fbf9d995e74c7dbd0ef7d7765941867..cd7c820be0b6029514ff74288e7bdd3f75b5d6b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,12 +15,17 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,250 +38,28 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { - namespace { -// Returns the expanded size of a filter used for depthwise convolution. -// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. -TensorShape ExpandedFilterShapeForDepthwiseConvolution( - const TensorShape& shape) { - int num_dims = shape.dims(); - CHECK_GE(num_dims, 2); - TensorShape expanded_shape = shape; - expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * - shape.dim_size(num_dims - 1)); - return expanded_shape; -} - -// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return xla::Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); -} - -// Create a mask for depthwise convolution that will make a normal convolution -// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tensor -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 -// -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 -// -// and divide B it by 2 to get -// 0 0 1 1 2 2 -// -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. -xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. - expanded_feature_iota = - xla::Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); - - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); -} - -// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to -// build a depthwise convolution. -xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, - const xla::XlaOp& filter) { - int64 input_feature_dim = filter_shape.dims() - 2; - int64 output_feature_dim = filter_shape.dims() - 1; - int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); - int64 input_feature = filter_shape.dim_size(input_feature_dim); - - // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = filter_shape; - implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); - implicit_broadcast_filter_shape.set_dim(output_feature_dim, - depthwise_multiplier * input_feature); - return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); -} - -// Reduces the results of the convolution with an expanded filter to the -// non-expanded filter. -xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, - const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter_backprop, - xla::XlaBuilder* builder) { - auto masked_expanded_filter = xla::Select( - CreateExpandedFilterMask(filter_shape, builder), filter_backprop, - CreateExpandedZero(filter_shape, dtype, builder)); - return xla::Reshape( - // This reduce does not need inputs to be converted with - // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with - // ExpandedZero guarantees that only one element is non zero, so there - // cannot be accumulated precision error. - xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), - filter_shape.dim_sizes()); -} - class ConvOp : public XlaOpKernel { public: explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape input_shape = ctx->InputShape(0); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, ..., in_depth, out_depth] - const TensorShape filter_shape = ctx->InputShape(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES( - ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("input must be ", num_dims(), "-dimensional", - input_shape.DebugString())); - OP_REQUIRES( - ctx, filter_shape.dims() == num_dims(), - errors::InvalidArgument("filter must be ", num_dims(), - "-dimensional: ", filter_shape.DebugString())); - - // The last two dimension of the filter are the input and output shapes. - const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); - - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. - OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", input_shape.dim_size(feature_dim))); - - xla::XlaOp filter = ctx->Input(1); - if (depthwise_) { - filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); - } - - xla::ConvolutionDimensionNumbers dims; - std::vector window_strides(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_, 1); - std::vector rhs_dilation(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - - dims.set_input_batch_dimension(batch_dim); - dims.set_output_batch_dimension(batch_dim); - dims.set_input_feature_dimension(feature_dim); - dims.set_output_feature_dimension(feature_dim); - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_input_spatial_dimensions(dim); - dims.add_kernel_spatial_dimensions(i); - dims.add_output_spatial_dimensions(dim); - window_strides[i] = strides_.at(dim); - rhs_dilation[i] = dilations_.at(dim); - - int64 unused_output_size; - OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), filter_shape.dim_size(i), - rhs_dilation[i], window_strides[i], padding_, - &unused_output_size, &padding[i].first, &padding[i].second)); - } - - xla::XlaOp conv = xla::ConvGeneralDilated( - ctx->Input(0), filter, window_strides, padding, lhs_dilation, - rhs_dilation, dims, - /*feature_group_count=*/depthwise_ ? in_depth : 1); - ctx->SetOutput(0, conv); + xla::StatusOr conv = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); + OP_REQUIRES_OK(ctx, conv.status()); + ctx->SetOutput(0, conv.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); @@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel { public: explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - TensorShape input_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - - const TensorShape filter_shape = ctx->InputShape(1); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - auto filter = ctx->Input(1); - auto out_backprop = ctx->Input(2); - - // The input gradients are computed by a convolution of the output - // gradients and the filter, with some appropriate padding. See the - // comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(batch_dim); - dnums.set_output_batch_dimension(batch_dim); - dnums.set_input_feature_dimension(feature_dim); - dnums.set_output_feature_dimension(feature_dim); - - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the gradient. - dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); - dnums.set_kernel_output_feature_dimension(num_spatial_dims_); - - std::vector kernel_spatial_dims(num_spatial_dims_); - std::vector> padding(num_spatial_dims_); - std::vector lhs_dilation(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(i); - dnums.add_output_spatial_dimensions(dim); - - kernel_spatial_dims[i] = i; - padding[i] = {dims.spatial_dims[i].pad_before, - dims.spatial_dims[i].pad_after}; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations_[dim]; - } - - // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); - - // activation gradients - // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums, - /*feature_group_count=*/ - depthwise_ ? out_backprop_shape.dim_size(feature_dim) / - filter_shape.dim_size(num_spatial_dims_ + 1) - : 1); - - ctx->SetOutput(0, in_backprop); + TensorShape input_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); + xla::Shape input_shape = + TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + + xla::StatusOr in_backprop = + MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, + ctx->Input(1), ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, in_backprop.status()); + ctx->SetOutput(0, in_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); @@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel { public: explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - - OP_REQUIRES( - ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape activations_shape = ctx->InputShape(0); - TensorShape filter_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp activations = ctx->Input(0); - xla::XlaOp gradients = ctx->Input(2); - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); - - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); - - std::vector> padding(num_spatial_dims_); - std::vector rhs_dilation(num_spatial_dims_); - std::vector window_strides(num_spatial_dims_); - std::vector ones(num_spatial_dims_, 1); - - // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_output_spatial_dimensions(i); - } - dnums.set_output_batch_dimension(num_spatial_dims_); - dnums.set_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(dim); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - // - const int64 padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. - // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; - - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64 pad_before = - padding_ == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations_[dim]; - } - - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (depthwise_) { - filter_backprop = ContractFilterForDepthwiseBackprop( - ctx, filter_shape, ctx->input_type(0), filter_backprop, b); - } - ctx->SetOutput(0, filter_backprop); + TensorShape filter_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); + xla::Shape filter_shape = + TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); + + xla::StatusOr filter_backprop = MakeXlaBackpropFilterConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, + ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, filter_backprop.status()); + ctx->SetOutput(0, filter_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector dilations_; - std::vector strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index ef1015552d181a183d412f9c269dd5ec608b388f..234f7b4a019c9aac4bac4f906ddbae166ecd9a80 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // compute valid broadcast shapes, but rely below on XLA to // automatically perform the broadcast assuming its valid shapes are // a superset of TensorFlow's valid shapes. - BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), + /*fewer_dims_optimization=*/false); if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", lhs_shape.DebugString(), " vs. ", @@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } /* static */ std::pair XlaBinaryOp::Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper) { - // Manually construct the broadcasting since MapN does not do - // automatic broadcasting. The bcast helper ensures that - // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and - // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have - // the same shape, so can be operated on by MapN. - - // First reshape the inputs, which should be a metadata-only - // operation since we are flattening the dimensions in order. - auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); - - // Next broadcast the necessary input dimensions. We rely on the - // XLA optimizer to be smart about the fact that we are asking - // it to broadcast size 1 on some of these dimensions, to avoid - // adding complexity to this code. - auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); - int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); - int rhs_size = broadcast_helper.y_bcast().size(); - - // Now reshape them to the correct output shape. After the - // broadcast each side is twice as wide as it should be, since the - // broadcast dimensions were prepended to the shape. Reshape - // flattening each original dimension with the prepended broadcast - // dimension. E.g. if we started out with lhs_shaped with shape - // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have - // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. - std::vector lhs_reorder; - for (int i = 0; i < lhs_size; ++i) { - lhs_reorder.push_back(i); - lhs_reorder.push_back(i + lhs_size); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) { + auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; } - auto lhs_output = - xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); - std::vector rhs_reorder; - for (int i = 0; i < rhs_size; ++i) { - rhs_reorder.push_back(i); - rhs_reorder.push_back(i + rhs_size); + auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; } - auto rhs_output = - xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); - - return {lhs_output, rhs_output}; + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 6653944a911588b7bc88d67b8cdd2c17850530f0..516ead4bfe89b4ddeee11dcc6410a838d04f28a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel { // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. static std::pair Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 33a73fe5fdf403e513be085dd7bcea3255277b4a..921b4340c0ac674a5ad7d17aaf54f1cf36975151 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + OP_REQUIRES(context, output_size <= kint32max, + errors::InvalidArgument("Need output_size <= kint32Max, got ", + output_size)); xla::XlaOp score_thresh = context->Input("score_threshold"); xla::XlaOp iou_thresh = context->Input("iou_threshold"); @@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); - // num_valid is scalar. - xla::XlaOp num_valid = xla::Reduce( + // num_valid is scalar. Value should be bound by output_size. + xla::XlaOp num_valid_total = xla::Reduce( ones_included, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); + xla::XlaOp num_valid = + xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); xla::XlaOp output_tuple = TopK(scores_included, output_size); xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d9a0257b70bcf302dea77db2e9f7fa7b4543e038..7b2bb4a7c50fc954237e09a32f71009f790b60d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -std::vector Make1DKernel(int64 n) { +xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return kernel; + return xla::ConstantR1(builder, kernel); } // Kernels with more than 16 spatial elements are considered intense and the @@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, absl::Span kernel_size, int64 channels) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + auto depthwise_kernel = xla::Broadcast( + xla::Zero(builder, xla::F32), + {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); - auto diag = xla::ConvertElementType( - xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, - 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); return xla::Mul( - xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + Make1DKernel(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, absl::Span kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); - - auto diag = xla::ConvertElementType( - xla::Eq( - xla::Broadcast(channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); - if (dim == 1) { - return xla::Mul( - diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}); - } - return xla::Mul(diag, - xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + auto depthwise_kernel = + xla::Broadcast(xla::Zero(builder, xla::F32), + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}); + return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + /*broadcast_dimensions=*/{dim}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { dimension_numbers.add_input_spatial_dimensions(1 + i); dimension_numbers.add_output_spatial_dimensions(1 + i); @@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, upper_padding[0]}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); output = xla::ConvGeneralDilated( @@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // Add broadcasts to handle expanding from a size == 1 dimension to a @@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(1 + i); - dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_input_spatial_dimensions(i + 1); + dimension_numbers.add_output_spatial_dimensions(i + 1); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = @@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/{1, dims.stride[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 3d81ae9eb89a80e5b89b180ad77521c5ed15e79d..f210bfbd886e48b8d7972393ed1899491486646c 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } - xla::Shape xla_shape = - xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + // The argmax function expects row-major layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::S64, output_shape.dim_sizes()); + std::vector arg_shapes; + for (const xla::XlaOp& arg : args) { + auto shape_status = b.GetShape(arg); + OP_REQUIRES_OK(ctx, shape_status.status()); + xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); + *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( + xla::ShapeUtil::Rank(arg_shape)); + arg_shapes.push_back(std::move(arg_shape)); + } // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = - xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); break; case 2: - output = - xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0764e5503db583351e92a144b2c361e8875161d3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class DataFormatVecPermuteOp : public XlaOpKernel { + public: + explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); + OP_REQUIRES( + ctx, src_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + TensorFormat data_format; + OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); + OP_REQUIRES( + ctx, dst_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + const TensorShape input_tensor_shape = ctx->InputShape(0); + int input_rank = input_tensor_shape.dims(); + OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2, + errors::InvalidArgument( + "Input must be a vector or matrix, but got shape ", + input_tensor_shape.DebugString())); + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(0) == 4, + errors::InvalidArgument( + "First dimension of input must be of size 4, but got shape ", + input_tensor_shape.DebugString())); + if (input_rank == 2) { + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(1) == 2, + errors::InvalidArgument( + "Second dimension of 2D input must be of size 2, but got shape ", + input_tensor_shape.DebugString())); + } + std::vector dst_indices(4, 0); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (src_format_[i] == dst_format_[j]) { + dst_indices[i] = j; + break; + } + } + } + auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); + if (input_rank == 2) { + keys = xla::BroadcastInDim( + keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + } + auto sorted = xla::Sort(keys, ctx->Input(0), 0); + auto output = xla::GetTupleElement(sorted, 1); + ctx->SetOutput(0, output); + } + + private: + string src_format_; + string dst_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); +}; + +// TODO(b/115384656): Support DT_INT64. +REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32), + DataFormatVecPermuteOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28db71075fb8da269c55edbdb667193e..8eee5b12991fb377203d780cecd8916952bd699a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector window_dimensions; std::vector window_strides; + std::vector base_dilations; + std::vector window_dilations; OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaReduceWindow") .CompileTimeConstInput("window_dimensions") .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("base_dilations") + .CompileTimeConstInput("window_dilations") .CompileTimeConstInput("padding"), ReduceWindowOp); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1ce9856a3c2854fd2776827d6c4b76f..57afd608de820573821d605cadcc8779474b5fd6 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel { } auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); + *reducer, window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); output = XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 25a5bcbe1dd27d741ce3b74125ba9ce425ee78f3..0c32b8def0f7b741c93e803f8359b6504087e257 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::LiteralSlice& start_literal, - const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, - Tensor* output) { +xla::StatusOr CreateRangeTensor( + const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal, ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) : std::ceil(std::abs((limit - start) / delta))); - *output = Tensor(DataTypeToEnum::v(), TensorShape({size})); - auto flat = output->flat(); - T val = start; - for (int64 i = 0; i < size; ++i) { - flat(i) = val; - val += delta; - } - return Status::OK(); + return xla::ConstantR0(builder, start) + + xla::ConstantR0(builder, delta) * + xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType(), + size); } class RangeOp : public XlaOpKernel { @@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); DataType type = input_type(0); - Tensor output; - Status status; + xla::StatusOr output; switch (type) { case DT_INT32: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_INT64: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_FLOAT: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; case DT_DOUBLE: - status = CreateRangeTensor(start, limit, delta, &output); + output = CreateRangeTensor(start, limit, delta, ctx->builder()); break; default: - status = errors::InvalidArgument("Invalid type for Range ", + output = errors::InvalidArgument("Invalid type for Range ", DataTypeString(type)); } - OP_REQUIRES_OK(ctx, status); - ctx->SetConstantOutput(0, output); + OP_REQUIRES_OK(ctx, output.status()); + ctx->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 2e0a69b70ef91fb5fee8aac888fdc90517c1356e..c8a0f31a0375abacaca26688a23f4835e11c692e 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01ccb303091a6d37d1aeb4b2a3377dc638..45f03d8c2175fc8b425b329b90893bb54d7f1d87 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), context->Input("values")); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..74d4fcc425bdadb70a7bedf2487deaf6c4a4f7b9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorList operators. + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, + TensorShape* tensor_list_shape) { + auto shape_or_status = builder->GetShape(op); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + tensor_list_shape); +} + +class TensorListReserveOp : public XlaOpKernel { + public: + explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + + TensorShape tensor_shape; + tensor_shape.AddDim(num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp); +}; + +REGISTER_XLA_OP(Name("TensorListReserve") + .CompileTimeConstInput("element_shape") + .CompileTimeConstInput("num_elements"), + TensorListReserveOp); + +class EmptyTensorListOp : public XlaOpKernel { + public: + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure( + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. Use TensorListReserve instead.")); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); +}; + +REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); + +class TensorListElementShapeOp : public XlaOpKernel { + public: + explicit TensorListElementShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + shape.RemoveDim(0); + + switch (shape_type_) { + case DT_INT64: + ctx->SetOutput(0, xla::ConstantR1(b, shape.dim_sizes())); + break; + case DT_INT32: { + std::vector size; + for (int64 s : shape.dim_sizes()) { + size.push_back(s); + } + ctx->SetOutput(0, xla::ConstantR1(b, size)); + break; + } + default: + ctx->CtxFailure( + errors::InvalidArgument("Unsupported shape type requested")); + return; + } + } + + private: + DataType shape_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); +}; + +REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); + +class TensorListPushBackOp : public XlaOpKernel { + public: + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp list = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(1); + + xla::XlaOp ta = xla::GetTupleElement(list, 0); + xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + ctx->SetOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); + +class TensorListPopBackOp : public XlaOpKernel { + public: + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); + + index = index - xla::ConstantR0(b, 1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); + + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetOutput(1, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 8597e7f139d8d32b7e08782e70a4ee44d02618f2..1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -31,6 +31,22 @@ cc_library( ], ) +cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "cholesky", srcs = ["cholesky.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 64f2d781a694393f6fabcd9f443cdb4911921c97..5400e8834cb9807f6dd71abe7789b2672e29e905 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -100,16 +100,6 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose - // HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs, &precision_proto); - } - xla::DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e402ef855cd7c114332d84032bc869232404fc8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { + +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + absl::Span input_dims = + xla::AsInt64Slice(input_shape.dimensions()); + + if (input_dims == output_dims) { + return input; + } + + if (input_dims.size() > output_dims.size()) { + return errors::InvalidArgument( + "Input shape (", xla::ShapeUtil::HumanString(input_shape), + ") must have rank less than or equal to the output shape [", + absl::StrJoin(output_dims, ","), "]"); + } + + std::vector broadcast_dims; + std::vector broadcast_shape; + auto input_it = input_dims.rbegin(); + for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend(); + ++output_it) { + if (input_it != input_dims.rend()) { + if (!(*output_it == 0 && *input_it == 0) && + !(*input_it != 0 && *output_it % *input_it == 0)) { + return errors::InvalidArgument("Invalid shape broadcast from ", + xla::ShapeUtil::HumanString(input_shape), + " to [", absl::StrJoin(output_dims, ","), + "]"); + } + + broadcast_dims.push_back(broadcast_shape.size()); + if (*output_it == *input_it) { + broadcast_shape.push_back(*output_it); + } else if (*output_it != *input_it) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(*input_it); + broadcast_shape.push_back(*output_it / *input_it); + } + ++input_it; + } else { + broadcast_shape.push_back(*output_it); + } + } + TF_RET_CHECK(input_it == input_dims.rend()); + + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::BroadcastInDim( + input, + xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), + broadcast_dims); + if (broadcast_shape != output_dims) { + output = xla::Reshape(output, output_dims); + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..591e696f06b994a7fdea58bc95ba785f683ce7d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 38dfde165df47ca78a25a068a901cd1071aa55e2..2b1c2ced925d9fee7392986015a6e716a94d356f 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -38,12 +38,10 @@ xla::StatusOr XlaScatter( combiner, xla::XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); - TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - absl::Span buffer_dims = - xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -81,104 +79,129 @@ xla::StatusOr XlaScatter( } } - // Shape of the non-indexed dimensions of the buffer. - std::vector buffer_shape_post_axes( - buffer_dims.begin() + num_index_dims, buffer_dims.end()); - - // Flatten the major dimensions of indices and updates into a single dimension - // for ease of iteration. - std::vector flat_indices_shape({num_indices}); - if (indices_are_vectors) { - flat_indices_shape.push_back(num_index_dims); + // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of + // shape [3,3]: + // NOTE: ***This case will not be generated by any of the tf.scatter ops.*** + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[3,2] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={0}, + // inserted_window_dims={1}, + // scatter_dims_to_operand_dims={1}, + // index_vector_dim=1 + // + // + // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of + // shape [3,3]: + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[2,3] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // + // + // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of + // shape [3,3,2] + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // updates = s32[2,2] parameter(2) + // scatter = s32[3,3,2] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0,1}, + // index_vector_dim=1 + // + // + // Example of a scatter updating slices of shape [] in a tensor of shape [1,1] + // + // operand = s32[1,1] parameter(0) + // indices = s32[1] parameter(1) + // updates = s32[1] parameter(2) + // scatter = s32[1,1] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // Note that updates operand would be broadcasted into [1] in this case. + // + + xla::ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(indices_are_vectors + ? indices_shape.dimensions_size() - 1 + : indices_shape.dimensions_size()); + + int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); + int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 num_window_dims_in_updates = buffer_rank - num_index_dims; + + // If the rank of `updates` is 0 and does not match the expected rank of + // updates, broadcast `updates` to the expected shape of updates. + auto new_updates = updates; + std::vector expected_updates_dims(indices_dims.begin(), + indices_dims.end()); + for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { + expected_updates_dims.push_back(buffer_shape.dimensions(dim)); + } + int64 expected_updates_rank = expected_updates_dims.size(); + if (updates_rank == 0 && expected_updates_rank != 0) { + new_updates = xla::Broadcast(updates, expected_updates_dims); + TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); + updates_rank = xla::ShapeUtil::Rank(updates_shape); } - std::vector flat_updates_shape({num_indices}); - flat_updates_shape.insert(flat_updates_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = xla::Reshape(indices, flat_indices_shape); - auto flat_updates = xla::Reshape(updates, flat_updates_shape); - auto init = {flat_indices, flat_updates, buffer}; - - // Constructs the loop body. The implementation of scatter is essentially: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // update = dynamic-slice(updates, i) - // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) { - auto indices = loop_vars[0]; - auto updates = loop_vars[1]; - auto buffer = loop_vars[2]; - - auto zero_index = xla::ConstantLiteral( - body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); - - // Slice the i-th index from the indices array. - xla::XlaOp index; - auto indices_offset = xla::Reshape(i, {1}); - if (indices_are_vectors) { - indices_offset = xla::Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - - index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = xla::Collapse(index, {0, 1}); - } else { - index = xla::DynamicSlice(indices, indices_offset, {1}); + if (updates_rank > 0) { + for (int64 i = (updates_rank - num_window_dims_in_updates); + i < updates_rank; ++i) { + dim_numbers.add_update_window_dims(i); } + } - // Discard updates with negative indices, since some users expect this. - auto index_in_range = xla::ReduceAll( - xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), - xla::CreateScalarAndComputation(xla::PRED, body_builder)); - - // Make the index in bounds to prevent implementation defined behavior. - index = xla::Max(index, zero_index); - index = xla::Pad( - index, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - - // Slice the i-th index from the updates array. - auto updates_offset = xla::Reshape(i, {1}); - updates_offset = xla::Pad( - updates_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - std::vector flat_updates_slice_shape({1}); - flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - auto update = - xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); - - // Unflatten the major (iteration) dimensions of the slice to their - // original shape. - std::vector updates_slice_shape(num_index_dims, 1); - updates_slice_shape.insert(updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - update = xla::Reshape(update, updates_slice_shape); - - // Apply the update to the buffer. If there is a combiner, use it to merge - // the current values with the update. - auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); + for (int64 i = 0; i < num_index_dims; ++i) { + dim_numbers.add_inserted_window_dims(i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + + // Build the combiner computation. + xla::XlaComputation combiner_computation; + { + xla::XlaBuilder cb("scatter-combiner"); + auto xla_scalar_shape = + xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); + auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0"); + auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1"); if (combiner) { - update = combiner(current_value, update, body_builder); + combiner(p0, p1, &cb); } - // Use the current value instead of the update if the index is out of - // bounds. - update = xla::Select(index_in_range, update, current_value); - // Apply the update. - buffer = xla::DynamicUpdateSlice(buffer, update, index); - - return std::vector{indices, updates, buffer}; - }; - - TF_ASSIGN_OR_RETURN(auto outputs, - XlaForEachIndex(num_indices, indices_shape.element_type(), - body_fn, init, "scatter", builder)); - return outputs[2]; + combiner_computation = cb.Build().ConsumeValueOrDie(); + } + + VLOG(3) << "Scatter op:"; + VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); + VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape); + VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); + VLOG(3) << " Scatter Dimension Numbers: "; + VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); + VLOG(3) << " update_window_dims: [" + << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]"; + VLOG(3) << " inserted_window_dims: [" + << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]"; + VLOG(3) << " scatter_dims_to_operand_dims: [" + << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",") + << "]"; + + return xla::Scatter(buffer, indices, new_updates, combiner_computation, + dim_numbers); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 13a5f1b850a612bddeeac39bef431c19925351ca..4cf478c4b9b4316f1cf43f45d1bf90afa648fb11 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -34,7 +34,11 @@ namespace tensorflow { // Otherwise, `indices_are_vectors`, then indices are multidimensional and the // minor dimension of `indices` represents a vector of indices. // -// If any indices are negative, the corresponding update is discarded. +// If `updates` is a scalar, then it will be broadcasted into the expected shape +// of updates. +// +// If any part of the update region is out-of-bounds, the corresponding update +// is discarded. // // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index ed452bceeb5a599ccbb27c38f80c08777db8529b..15f4c38da29507da9e092c1d5725b5f95a81d1b9 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,48 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. - { - std::vector int64_values = {1, 2, 3}; - xla::Literal int64_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int64_values)); - } + std::vector int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); +} + +template +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types, std::pair, + std::pair, std::pair, + std::pair>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. - Tensor host_tensor; - std::vector int32_values = {10, 11}; - xla::Literal int32_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int32_values)); + Tensor host_tensor; + std::vector int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector qint32_values = {10, 11}; - test::ExpectTensorEqual(host_tensor, - test::AsTensor(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, + &host_tensor) + .ok()); + std::vector qint_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 02363500efe1a11348eaf7d8b99da76307acdd3c..bd2c0a5ee88869ba60701c0a7ace05857452eed9 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -121,8 +121,8 @@ Wraps the XLA DynamicSlice operator, documented at DynamicSlice extracts a sub-array from the input array at dynamic start_indices. The size of the slice in each dimension is passed in size_indices, which specify the end point of exclusive slice intervals in each -dimension -- [start, start + size). The shape of start_indices must be rank == -1, with dimension size equal to the rank of operand. +dimension -- [start, start + size). The shape of start_indices must have rank 1, +with dimension size equal to the rank of operand. input: A `Tensor` of type T. @@ -131,7 +131,8 @@ start_indices: Rank 1 tensor of N integers containing the starting indices of start_indices: List of N integers containing the slice size for each dimension. Each value must be strictly greater than zero, and start + size - must be less + must be less than or equal to the size of the dimension to avoid + implementation defined behavior. )doc"); REGISTER_OP("XlaDynamicUpdateSlice") @@ -282,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") @@ -353,12 +356,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27dd18a9bbd5aceece41aaf61eb185acb537b3b6..5e86b5d8ec0a2690f004bc67decea09185d9cbb6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -320,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -332,22 +334,27 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. """ window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) @@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 20f2ce2919701731ef6e90d368b67545af95e8f9..72b240996fb4d9dcb5f5dfd919da618cbae08c16 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/gtl/flatmap.h" +#include "absl/container/flat_hash_map.h" namespace tensorflow { /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( @@ -30,9 +30,9 @@ namespace tensorflow { } } -static gtl::FlatMap* +static absl::flat_hash_map* CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap; + auto* result = new absl::flat_hash_map; auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { @@ -103,15 +103,15 @@ CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const absl::flat_hash_map& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = - CreateResourceOpInfoMap(); + static absl::flat_hash_map* + op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap& op_infos = + const absl::flat_hash_map& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index a85ef040a7b65c2f6e405c3444eaef3019137b4b..956f597301d28d781a9df7ab2086ed79d4c8bf9d 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - gtl::FlatMap known_resource_ops; + absl::flat_hash_map known_resource_ops; for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 9d1992205b02665b99b1bd15b7b65a1fb8c35a51..b589512dcdfa32050281120aba6a5ae89a980c2f 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, // Convert a TensorShape into the equivalent XLA Shape proto. Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector dimensions(rank); std::vector layout(rank); @@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); - - *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); - return Status::OK(); + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 58240b9c965a194b9380ac7cd477ce7344e5ebe3..f7e34a5b40c2f9244c029ed325a76322b8cf54dd 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape); +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec..f31bfb45a2f4db270446eb59259969dc0ab63a8e 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map BuildNodeIndex(const Graph& graph) { + std::unordered_map index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed23f3fca0f59b131dc73152e0947b72..350a868568531c0d073e0cf600327d1ff9d62e3a 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. +std::unordered_map BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index d6f42bac86f1ef359531d67b652d43d851d7ac02..01dd3ba10fec85e6b1d411fbd32fbf9c58b5fe11 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def, } if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - return false; + // Gradient op has "f" attr, which is set to the function we are getting + // gradient for. We need to functionalize the gradient function. + return true; } for (const auto& iter : node_def.attr()) { @@ -357,17 +357,18 @@ std::vector GetAssociatedFunctions( if (flr->GetFunctionLibraryDefinition()->Contains(op)) { // This is a function call node. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); - results.emplace_back(AssociatedFunctionInfo(op, attrs)); + results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs)); } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. + // This is a SymbolicGradient op. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { if (iter.second.has_func()) { VLOG(2) << "Found function attr for node " << node.name() << ": " << iter.first << " = " << iter.second.func().name(); - results.emplace_back(AssociatedFunctionInfo( + results.emplace_back(AssociatedFunctionInfo::FunctionAttr( iter.second.func().name(), iter.second.func().attr(), iter.first)); } } @@ -410,6 +411,21 @@ Status RewriteAssociatedFunction( graph->RemoveNode(node); break; } + case AssociatedFunctionInfo::kSymbolicGradient: { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr( + node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func)); + GradientDef gradient_def; + gradient_def.set_function_name(func.name()); + gradient_def.set_gradient_func(rewritten_function_name); + string original_grad_func = fld->FindGradient(func.name()); + if (original_grad_func.empty()) { + TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def)); + } else if (original_grad_func != rewritten_function_name) { + TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def)); + } + break; + } case AssociatedFunctionInfo::kFunctionAttr: { // Change function attr to rewritten functions. NameAttrList func; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 6065d0bb9a3abd23b8911c5049914be8a5f23b99..53eab8b63e2fc8aa3dfb0bacfe065897ca775bd0 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -65,21 +65,33 @@ uint32 GetXLARandomSeed(); class AssociatedFunctionInfo { public: enum AssociatedFunctionType { - kFunctionCallNode = 0, - kFunctionAttr = 1, + kFunctionAttr = 0, + kFunctionCallNode = 1, + kSymbolicGradient = 2, }; - // The node is a function call. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) - : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} - // The function is an attr of the node. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, - const string& attr_name) - : type_(kFunctionAttr), - func_name_(func_name), - attrs_(attrs), - attr_name_(attr_name) {} + static AssociatedFunctionInfo FunctionAttr(const string& func_name, + const AttrValueMap& attrs, + const string& attr_name) { + return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name); + } + + // The node is a function call. + static AssociatedFunctionInfo FunctionCall(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs, + /*attr_name=*/""); + } + + // The node is a SymbolicGradient op. + static AssociatedFunctionInfo SymbolicGradient(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs, + /*attr_name=*/""); + } AssociatedFunctionType type() const { return type_; } @@ -90,6 +102,13 @@ class AssociatedFunctionInfo { const AttrValueMap& attrs() const { return attrs_; } private: + AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name, + const AttrValueMap& attrs, const string& attr_name) + : type_(type), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + // Available for all instances. AssociatedFunctionType type_; string func_name_; @@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def, // Gets functions associated with the node. Current cases: // 1. For function call node, its function name; -// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient", +// and returned attrs will be this node's attributes; +// 3. For nodes like XlaWhile/XlaIf, all their function attributes. std::vector GetAssociatedFunctions( const Node& node, FunctionLibraryRuntime* flr); // Changes associated functions for the node. Current cases: // 1. For function call node, creates a new node with the new function name and // remove the old node; -// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +// 2. For SymbolicGradient op, add or replace GradientDef in +// FunctionLibraryDefinition; +// 3. For nodes like XlaWhile/XlaIf, modify their function attributes. Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1bfaa6cab0d896ee074cfd4e2b283ae4..d00b1376620c0c9d112c7d7426758f6d3f25e86f 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; - return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h index bda667eb1f16b80da415c7c5205df96a4ae93e4c..6354216eee7978dc2b4a59f5792a70f67d530b9b 100644 --- a/tensorflow/compiler/tf2xla/type_util.h +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -25,6 +25,14 @@ namespace tensorflow { // Converts a Tensorflow DataType to an XLA PrimitiveType. Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); +// N.B.: there is intentionally no function to convert an XLA PrimitiveType to +// a TensorFlow DataType. The mapping from TF types to XLA types is not +// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the +// inverse would not be a well-defined function. If you find that you want the +// inverse mapping, then most likely you should be preserving the original +// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow +// type. + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 105f3b61d59acc7ed502216a5e0ceb69ee914131..b2c57e88803e0661a9a514f844dff97ff9edf2ea 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, std::unique_ptr graph = GetGraph(fbody); + // Clear the "_kernel" attribute if it is set to "host". This is used to + // indicate that a computation should happen on the host instead of the + // accelerator, but doesn't make sense in XLA. + const char* const kKernelAttr = "_kernel"; + for (Node* n : graph->nodes()) { + string value; + if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + n->ClearAttr(kKernelAttr); + } + } + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. @@ -325,8 +336,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, - step_container.get()); + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); @@ -334,10 +344,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, } // Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. +// `args` is the list of input arguments, `retvals` is the list of retvals +// produced by _Retval operators, in index order. // If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 23d04d43b358e858ad1ab2463322ce0ab93b23c2..bc44301d405102921de21da4bd9407032783838c 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -20,21 +20,6 @@ limitations under the License. namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { - // TODO(b/34339814): implement inverse erf for double types and remove this - // workaround. - if (kdef->op() == "RandomStandardNormal") { - kdef->clear_constraint(); - // Change the type constraint to permit only DTD_FLOAT. - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name("dtype"); - attr_constraint->mutable_allowed_values()->mutable_list()->add_type( - DT_FLOAT); - return true; - } - // TODO(b/26783907): The CPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2a9eaeee146bf6d792e010df7e041f9986b2c77e..dd3498ef7aa242d3ad946cae5f60bc2c8853a342 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a3a0d10cc06cd4afceec728b7dbe287389099b9d..aa00a454968ad29495e34dc080e55b62bb0b5f7b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -255,6 +255,11 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); + // Wraps OpKernelContext's allocate_output method while providing special + // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the + // type to allow mapping for variant to more generic types. + Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index b0eeee3174eda7f552f1d8a1d5ece877e93f94ab..91d48125f1d21092db7e5f9307e44af9c16e4e2b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible compile time constant inputs."; return false; } + if (x.is_metadata_op != y.is_metadata_op) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible values for is_metadata_op."; + return false; + } return true; } @@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { return &it->second.front()->compile_time_constant_inputs; } +/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return false; + } + + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // is_metadata_op, so only the first match is returned. + return it->second.front()->is_metadata_op; +} + std::vector XlaOpRegistry::BackendNames() { std::vector names; XlaOpRegistry& registry = Instance(); @@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { + registration_->is_metadata_op = true; + return *this; +} + std::unique_ptr XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 74a4885f1f029628817f6ec3a36fcb98719d6a41..4b2c2bacd647b3e6fe500a942b116772550195ce 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { + {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -136,6 +137,10 @@ class XlaOpRegistry { static const std::unordered_set* CompileTimeConstantInputs( const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. + static bool IsMetadataOp(const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -192,6 +197,10 @@ class XlaOpRegistry { // Names of arguments that must be compile-time constants. std::unordered_set compile_time_constant_inputs; + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -256,6 +265,10 @@ class XlaOpRegistrationBuilder { // Mark 'input_name' as an argument whose value must be known at compile-time. XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + XlaOpRegistrationBuilder& IsMetadataOp(); + std::unique_ptr Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 76e36f3c46b22742b6cf0c86e89d17899338a60f..cc7390c6e60375b4c31c38f9f7dee25730f8f51e 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -193,6 +193,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", ], ) @@ -244,6 +245,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f825f67b447514a416f3a49ac8aad9dcf505f5a7..dc097f3696e22d75d7dc72ec4877a9c8b5dda059 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -220,6 +220,8 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 25cc37edc43c28a636797c310c8882eea09a0ef3..ff0ec76a7f9b62fce0f14beae688cb0dd74847a1 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -97,13 +97,11 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // Create and run a program which produces a tuple with one element per - // parameter, then return the tuple's constituent buffers. - std::vector param_shapes(program_shape.parameters().begin(), - program_shape.parameters().end()); - auto fake_input_tuple = - MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); - return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); + std::vector> results; + for (const Shape& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(shape, client)); + } + return results; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 95ff6432a591f87845729b180397e33a85e5e9a5..e7cf9ae36389056c4732285dd9667f4dcd4115bd 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + case HloOpcode::kScatter: + // TODO(b/32495713): We aren't checking the embedded computation in + // Scatter. case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kParameter: @@ -1276,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { }); } -XlaOp XlaBuilder::CustomCall(const string& call_target_name, - absl::Span operands, - const Shape& shape) { +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1289,6 +1293,32 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); + instr.set_custom_call_opaque(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + "constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape; + ++operand_num; + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -1785,9 +1815,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, std::vector> padding_values = MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } @@ -1796,6 +1826,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1806,7 +1838,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferReduceWindowShape(operand_shape, init_shape, @@ -2289,7 +2322,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; - tensorflow::gtl::FlatSet related_calls; // Related computations. + absl::flat_hash_set related_calls; // Related computations. std::queue worklist; worklist.push(root->id()); related_ops.insert(root->id()); @@ -2681,8 +2714,18 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape) { - return builder->CustomCall(call_target_name, operands, shape); + absl::Span operands, const Shape& shape, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, @@ -2795,10 +2838,12 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, - padding); + base_dilations, window_dilations, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index d0c59fa6f27bc265c0868734ed95a196002fbd2e..933c0e7b4458344c1f3a98f2750566c92fdba264 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -577,11 +577,10 @@ class XlaBuilder { absl::Span operands); // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - XlaOp CustomCall(const string& call_target_name, - absl::Span operands, const Shape& shape); + XlaOp CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -673,6 +672,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -1029,7 +1030,7 @@ class XlaBuilder { // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. - tensorflow::gtl::FlatMap handle_to_index_; + absl::flat_hash_map handle_to_index_; // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of @@ -1037,7 +1038,7 @@ class XlaBuilder { std::map embedded_; // The unique parameter numbers. - tensorflow::gtl::FlatSet parameter_numbers_; + absl::flat_hash_set parameter_numbers_; // The metadata to attach to each op. This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata @@ -1195,7 +1196,12 @@ class XlaBuilder { friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape); + absl::Span operands, const Shape& shape, + const string& opaque); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); @@ -1246,6 +1252,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); friend XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups); @@ -1717,12 +1725,28 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); -// Enqueues a custom call instruction onto the computation. -// During code generation, a call instruction is emitted which targets a -// symbol with the name |call_target_name|. The |operands| are passed to the -// call instruction. |shape| is the resultant shape. +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape); + absl::Span operands, const Shape& shape, + const string& opaque = ""); + +// Overload which constructs a custom call with fixed layouts. The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = ""); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1814,6 +1838,8 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -1976,7 +2002,7 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // the last dimension is chosen by default. // // If both keys and values are provided: -// * The keys and the values must tensors with the same dimensions. The +// * The keys and the values must be tensors with the same dimensions. The // element types of the tensors may be different. // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and a tensor with their diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index a472747bd174e3bbd352f07f2ab092e678b81073..0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } +ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream( + stream_executor::Stream* stream) { + host_to_device_stream_ = stream; + return *this; +} + +stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const { + return host_to_device_stream_; +} + ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 416131be006e6ecddb47651f8b684c1d91df4892..ba3217f31b55bd1428f67da6154a46c8bc304053 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -65,6 +65,13 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; + // If set, this is the stream to perform any pre-computation transfers on. + // The platform of the stream must match the platform the executable was + // built for. A value of nullptr indicates the option has not been set. + ExecutableRunOptions& set_host_to_device_stream( + stream_executor::Stream* stream); + stream_executor::Stream* host_to_device_stream() const; + // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. ExecutableRunOptions& set_intra_op_thread_pool( @@ -90,6 +97,7 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; + stream_executor::Stream* host_to_device_stream_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index d310335618ded7b581e6ed632223218585bb791f..19667b7ed9d47896afd9a82a41de7997538b089b 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) { + std::vector layout(rank); + std::iota(layout.rbegin(), layout.rend(), static_cast(0)); + return MakeLayout(layout); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( absl::Span major_to_minor) { Layout layout; @@ -199,7 +205,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return Status::OK(); } - if (layout.format() == INVALID_FORMAT) { + if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", layout.ShortDebugString(), shape.ShortDebugString()); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index b78883c2d870043032306637730c4666665125a8..af032b1cae4c5645d6c7da55b779cd0a7336592e 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,10 @@ class LayoutUtil { static Layout MakeLayoutFromMajorToMinor( absl::Span major_to_minor); + // Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major + // dimensions. + static Layout MakeDescendingLayout(int64 rank); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 0d3136b0cc6a3a695eacb98c16200e46a144c571..3ed3afcfcede20fbf5c7d4f004378817febeb4c7 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // regression. flags->set_xla_cpu_enable_fast_math(true); flags->set_xla_gpu_enable_fast_math(true); + + flags->set_xla_force_host_platform_device_count(1); } // Allocates flag_values and flag_objects; this function must not be called more @@ -323,6 +325,17 @@ void AllocateFlags() { flag_values->xla_gpu_crash_on_verification_failures(), "Crashes the program on extra verification failures, e.g. cuDNN " "cross checking failures"), + tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for( + &DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are these many " + "host \"devices\". All of these host devices are backed by the same" + "threadpool. Setting this to anything other than 1 can increase " + "overhead from context switching but we let the user override this " + "behavior to help run tests on the host that run models in parallel " + "across multiple devices."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 5035f4198890857fcafd0156d7eaeeb4bc164322..656ce720a13d5c9622e9dc05ae04ddcac8cbeee5 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Literal literal(proto.shape()); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( @@ -725,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span start_indices, ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); + case PRED: + return SliceInternal(result_shape, start_indices); + case U8: + return SliceInternal(result_shape, start_indices); + case U16: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + case U64: + return SliceInternal(result_shape, start_indices); + case S8: + return SliceInternal(result_shape, start_indices); + case S16: + return SliceInternal(result_shape, start_indices); + case S32: + return SliceInternal(result_shape, start_indices); + case S64: + return SliceInternal(result_shape, start_indices); + case F16: + return SliceInternal(result_shape, start_indices); case BF16: return SliceInternal(result_shape, start_indices); + case F32: + return SliceInternal(result_shape, start_indices); + case F64: + return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); - case S32: - return SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -1850,6 +1870,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + if (LayoutUtil::IsSparseArray(subshape())) { + // Compute the number of elements (indices) in the sparse shape and reserve + // the necessary space in spare_indices. + TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) + << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + << "Unexpected number of indices in proto (" + << proto.sparse_indices_size() << ") for shape of rank " + << ShapeUtil::Rank(subshape()); + const int64 index_count = + proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + sparse_indices()->Resize(index_count); + + // Copy the indices from the proto into the SparseIndexArray object. + TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(), + proto.sparse_indices())); + } + switch (subshape().element_type()) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); @@ -1907,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } } break; case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; + return InvalidArgument("Should not be called on tuple shapes: %s", + ShapeUtil::HumanString(subshape())); default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + return InvalidArgument("Is called on unsupported shape: %s", + ShapeUtil::HumanString(subshape())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1e0a2ad0ddf81d6813942c77ae273e2ce24e735e..3cd3541fe1596600b4f0b43e3011e1f0322ac8fe 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -203,6 +203,10 @@ class LiteralBase { // Returns the count of the elements in the array at the given shape index in // this literal. int64 element_count(const ShapeIndex& index = {}) const { + if (index.empty()) { + // Common case, avoid GetSubshape(). + return ShapeUtil::ElementsIn(shape()); + } return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } @@ -852,9 +856,9 @@ class BorrowingLiteral : public LiteralBase { template absl::Span LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " @@ -865,9 +869,9 @@ absl::Span LiteralBase::Piece::data() const { template absl::Span LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 7ad287c8973367fb04583e6911ff75e76bdf5f1e..dd5b54e4c99998f676419cf98a3da16593338829 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) { absl::Span(expected_indices.data(), expected_indices.num_elements())); EXPECT_EQ(literal.data(), absl::Span(expected_values)); + + // Serialize then deserialize and verify the resulting literal. + TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto, + Literal::CreateFromProto(literal.ToProto())); + + EXPECT_EQ(literal_from_proto.sparse_indices()->data(), + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal_from_proto.data(), + absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index 787725e884c810fd724ab88ad7d4beaf3e0a6cc7..b507a2ef79f1d7e9ae632744675dddf574490805 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { @@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) { return safe_file_name; } +std::pair>*> +GetDirectoryExpanders() { + static auto* mutex = new tensorflow::mutex; + static auto* singleton = new std::vector>; + return {mutex, singleton}; +} + +// Runs all the directory expanders over x and returns the result. +string Expand(string x) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + for (const auto& f : *pair.second) { + x = f(x); + } + return x; +} + } // namespace Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name) { tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string expanded_dir = Expand(directory); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); + const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name); return tensorflow::WriteBinaryProto(env, path, message); } +void RegisterDirectoryExpander(const std::function& expander) { + auto pair = GetDirectoryExpanders(); + tensorflow::mutex_lock lock(*pair.first); + pair.second->push_back(expander); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 3667621367c7639c40ff17aee7b77305d4d34e33..f22fc8b8499dd4a5329276040331a2ed9e89bea9 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, const string& directory, const string& file_name); +// Registers a function that may either expand a dirpath or forward the original +// dirpath along as-is. +void RegisterDirectoryExpander(const std::function& expander); + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 9da5dc0d2d40cb10640fb0fd2c4c65b4f8e55346..ffa336f30417eab1e4b16e278a97130c5fd57f88 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( absl::Span window_strides, absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, - lhs_dilation, rhs_dilation, dimension_numbers); + lhs_dilation, rhs_dilation, dimension_numbers, + feature_group_count); } LocalOp LocalComputationBuilder::ConvertElementType( @@ -530,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalComputation& local_computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 1d5dfe591175735d58a5fe555fffc8043fa4de7e..43332e0abd410c08dc5a40f7de39dbc96d34a72c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -248,7 +248,8 @@ class LocalComputationBuilder { absl::Span > padding, absl::Span lhs_dilation, absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); @@ -277,6 +278,8 @@ class LocalComputationBuilder { const LocalComputation& local_computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index fa4366ff0789a3d05c26479a746a18dfcf7e902b..f8197488fb3bacb312cc7fbf149b773851992b8a 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -995,7 +995,30 @@ class ComputationBuilder(object): window_strides) return self._client.ReduceWindowWithGeneralPadding( operand, init_value, computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads) + window_dimensions, window_strides, (), (), pads) + + def ReduceWindowWithGeneralPadding( + self, operand, init_value, computation_to_apply, window_dimensions, + window_strides, base_dilations, window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + A LocalOp representing the added ReduceWindow op. + """ + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. @@ -1109,7 +1132,7 @@ class ComputationBuilder(object): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return self._client.DotGeneral(lhs, rhs, dimension_numbers) - def Conv(self, lhs, rhs, window_strides, padding): + def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): """Enqueues a Conv operation onto the computation. Args: @@ -1117,6 +1140,7 @@ class ComputationBuilder(object): rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the Conv operation. """ @@ -1125,10 +1149,11 @@ class ComputationBuilder(object): self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), - (), dimension_numbers) + (), dimension_numbers, + feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation): + lhs_dilation, rhs_dilation, feature_group_count=1): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: @@ -1138,6 +1163,7 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors. + feature_group_count: number of feature groups for grouped convolution. Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. @@ -1145,7 +1171,8 @@ class ComputationBuilder(object): dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1163,7 +1190,8 @@ class ComputationBuilder(object): return dimension_numbers def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers): + rhs_dilation, dimension_numbers, + feature_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. Args: @@ -1190,6 +1218,7 @@ class ComputationBuilder(object): labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not 'I' or 'O'. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the ConvGenralDilated operation. """ @@ -1215,7 +1244,8 @@ class ComputationBuilder(object): key=lambda i: rhs_spec.index(out_spec[i]))) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def Sort(self, operand, dimension=-1): """Enqueues a sort operation onto the computation.""" diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fd98e19457f61aade947aa354d2e415148d127f6..82103f03132e45ff822ce1ebcc2be47b24f5869f 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + feature_group_count = 2 + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]], + [[0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 97fcd37f6b89d6dd737c233ef19f55a8faa1b624..3abb3855a42b8b5222115262448d359da3a80e87 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -34,19 +34,28 @@ cc_library( ], ) -tf_cc_binary( - name = "grpc_service_main_cpu", +cc_library( + name = "grpc_service_main_library", srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings:str_format", ], ) +tf_cc_binary( + name = "grpc_service_main_cpu", + deps = [ + ":grpc_service_main_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + tf_cc_test( name = "grpc_client_test", srcs = ["grpc_client_test.cc"], diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index d6b5149a24c491d1e9d7cd9119b36d7eb2ad65d3..522ab99fb1feff69610af887b58f197211cdb21f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -20,6 +20,7 @@ limitations under the License. #include "grpcpp/server_builder.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,8 +30,15 @@ namespace { int RealMain(int argc, char** argv) { int32 port = 1685; + bool any_address = false; + string platform_str; std::vector flag_list = { - tensorflow::Flag("port", &port, "port to listen on"), + tensorflow::Flag("platform", &platform_str, + "The XLA platform this service should be bound to"), + tensorflow::Flag("port", &port, "The TCP port to listen on"), + tensorflow::Flag( + "any", &any_address, + "Whether to listen to any host address or simply localhost"), }; string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); + se::Platform* platform = nullptr; + if (!platform_str.empty()) { + platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie(); + } std::unique_ptr service = - xla::GRPCService::NewService().ConsumeValueOrDie(); + xla::GRPCService::NewService(platform).ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(absl::StrFormat("localhost:%d", port)); + string server_address( + absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port)); + builder.SetMaxReceiveMessageSize(INT_MAX); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); LOG(INFO) << "Server listening on " << server_address; server->Wait(); - return 0; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 17a557ccc3c0e069ea00be49829f634e64ff9533..f9f741aaee20ee19f62fbcab6ac05fd34dc06ad6 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -146,6 +146,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -182,6 +184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -251,6 +254,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -296,6 +300,7 @@ cc_library( "hlo_opcode.cc", "hlo_schedule.cc", "hlo_sharding.cc", + "hlo_sharding_metadata.cc", ], hdrs = [ "dfs_hlo_visitor.h", @@ -309,6 +314,7 @@ cc_library( "hlo_opcode.h", "hlo_schedule.h", "hlo_sharding.h", + "hlo_sharding_metadata.h", ], deps = [ ":hlo_casting_utils", @@ -333,6 +339,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -365,8 +373,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/utility", ], ) @@ -392,6 +403,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], ) @@ -482,6 +494,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -590,6 +604,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -772,6 +787,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -899,6 +915,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -948,6 +965,8 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -983,6 +1002,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1030,6 +1051,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1083,6 +1106,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1121,6 +1145,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1142,10 +1168,43 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_module_group_metadata", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], @@ -1160,6 +1219,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -1180,6 +1240,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1224,6 +1286,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1244,6 +1308,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1258,16 +1323,26 @@ cc_library( ], ) +cc_library( + name = "fusion_queue", + hdrs = ["fusion_queue.h"], + deps = [ + ":hlo", + ], +) + cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":fusion_queue", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -1294,6 +1369,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1349,6 +1426,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -1604,6 +1682,8 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -1635,6 +1715,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1760,42 +1841,6 @@ tf_cc_test( ], ) -cc_library( - name = "inliner", - srcs = ["inliner.cc"], - hdrs = ["inliner.h"], - deps = [ - ":hlo", - ":hlo_pass", - ":hlo_query", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "inliner_test", - srcs = ["inliner_test.cc"], - deps = [ - ":cpu_plugin", - ":hlo", - ":hlo_matchers", - ":inliner", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", - ], -) - cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], @@ -2007,6 +2052,7 @@ cc_library( ":logical_buffer", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2042,6 +2088,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2063,6 +2110,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2146,6 +2194,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2167,6 +2216,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -2227,6 +2278,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2283,6 +2336,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2309,6 +2364,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2380,6 +2437,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2392,6 +2450,7 @@ tf_cc_test( ":hlo", ":hlo_parser", ":hlo_verifier", + ":layout_assignment", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2424,6 +2483,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2526,6 +2587,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_module_group", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2551,12 +2613,34 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) +tf_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_cse", srcs = ["hlo_cse.cc"], @@ -2570,6 +2654,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2644,26 +2729,12 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) -cc_library( - name = "hlo_sharding_metadata", - srcs = ["hlo_sharding_metadata.cc"], - hdrs = [ - "hlo_sharding_metadata.h", - ], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:shape_tree", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "hlo_domain_verifier", srcs = ["hlo_domain_verifier.cc"], @@ -2714,7 +2785,6 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -3057,6 +3127,7 @@ cc_library( ":buffer_assignment", ":hlo", ":hlo_proto", + ":hlo_verifier", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", ], @@ -3090,6 +3161,7 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3212,6 +3284,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3241,6 +3315,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3297,6 +3372,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -3324,7 +3401,6 @@ cc_library( deps = [ ":hlo", ":hlo_lexer", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -3382,6 +3458,39 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "map_inliner", + srcs = ["map_inliner.cc"], + hdrs = ["map_inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "map_inliner_test", + srcs = ["map_inliner_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":map_inliner", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 5458159d149c627b1121fd8a30e073b712542390..ca71f2cc129fc5d14e454c98a6e5ebf2e94cd7d2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -745,12 +745,25 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( } const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { return hlo; } - return computation_->AddInstruction( - HloInstruction::CreateReshape(dot->shape(), hlo)); + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + }; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + hlo = as_type(hlo, dot->shape().element_type()); + if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + hlo = computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + } + return hlo; + }; + + auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { + return AddReduce(as_type(hlo, F32), dim); }; auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, @@ -770,7 +783,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), Flatten(rhs)), 0)))); return true; } @@ -804,17 +817,17 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, - reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(Flatten(lhs), rhs), 0)))); return true; } TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary( - AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), - rhs_collapsing_dim), - rhs), - rhs_collapsing_dim)))); + dot, reshape_if_necessary(add_reduce_in_f32( + multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); return true; } @@ -826,7 +839,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, reshape_if_necessary(AddReduce( + dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), lhs_collapsing_dim)), lhs_collapsing_dim)))); @@ -1061,7 +1074,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); - auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto memoized_shape = + ShapeUtil::MakeShape(dot->shape().element_type(), {m, n}); auto* memoized_inst = computation_->AddInstruction( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); @@ -1109,10 +1123,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. + if ((dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) || + ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || + ShapeUtil::Rank(dot->shape()) > 2) { return Status::OK(); } @@ -2041,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. + if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr @@ -2187,7 +2209,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { } // If it is key/value sort, the output of sort is a tuple. return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + sort, HloInstruction::CreateTuple(sort->operands())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index b864c372fa5877ca329d2efbbf7d747c763ae2c0..9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloPassInterface { +class AlgebraicSimplifier : public HloModulePass { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3fc1ba24271b40de0a24ed4c957cd83aca736f55..42d1f337dc22b91dcef4eb8ed4c0c57c6febeb70 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2133,16 +2133,20 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto values = builder.AddInstruction( - HloInstruction::CreateParameter(1, values_shape, "values")); + auto values0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values0")); + auto values1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, values_shape, "values1")); builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, + keys, {values0, values1})); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(keys, values0, values1)); } // Used for TEST_Ps that test merging (or not) of a kPad instruction into a @@ -3233,17 +3237,18 @@ INSTANTIATE_TEST_CASE_P( class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< - ::testing::tuple> {}; + ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { int m, k, n; bool transpose_lhs, transpose_rhs; - std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + PrimitiveType element_type; + std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam(); - Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); - Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); - Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -3285,7 +3290,7 @@ INSTANTIATE_TEST_CASE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool())); + ::testing::Bool(), ::testing::Values(F32, BF16))); struct DotOfConcatTestSpec { int64 m; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f..ef5e211646e7b0b66b8e6c09948be58063422943 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -176,13 +176,13 @@ StatusOr> AllocationTracker::DeconstructTuple( } StatusOr> AllocationTracker::Resolve( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } StatusOr AllocationTracker::ResolveForReplica( - const GlobalDataHandle& data, int replica_id) { + const GlobalDataHandle& data, int replica_id) const { tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, ResolveInternal(data)); @@ -196,7 +196,7 @@ StatusOr AllocationTracker::ResolveForReplica( } StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index a7d8927cf7e90d764ff8046df16c71922b11478e..98d1a302a9f66f4a00e05d62837a79133e222687 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -64,13 +65,13 @@ class AllocationTracker { // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). StatusOr> Resolve( - const GlobalDataHandle& data); + const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id); + int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -86,7 +87,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. StatusOr> ResolveInternal( - const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If @@ -110,9 +111,9 @@ class AllocationTracker { // A map from device memory opaque value to allocation. One such map is // maintained per device ordinal. - using AllocationMap = tensorflow::gtl::FlatMap; + using AllocationMap = absl::flat_hash_map; - tensorflow::mutex mutex_; + mutable tensorflow::mutex mutex_; // Backend to use with this tracker. The backend supplies the memory allocator // to use when deallocating memory. @@ -123,10 +124,7 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - // - // This is not a TF FlatMap because (currently) FlatMap (and therefore - // AllocationMap) is not movable. - std::unordered_map opaque_to_allocation_map_ + absl::flat_hash_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the @@ -146,7 +144,7 @@ class AllocationTracker { // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then // free'd when both the view *and* the original tuple are Unregistered. This // refcounting is managed in opaque_to_allocation_map_. - tensorflow::gtl::FlatMap>> + absl::flat_hash_map>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index 79d37f08d3553321ebbabc44c8f2488b194954d5..5b625bf3b98b060531532f07de343f7ca4f09ac9 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -25,7 +25,7 @@ namespace xla { // Normally these would live in the algebraic simplifier, but we want to run // this to fixpoint (this pass reaches fixed point in one execution) before we // run the DotDecomposer. -class BatchDotSimplification : public HloPassInterface { +class BatchDotSimplification : public HloModulePass { public: StatusOr Run(HloModule* module) override; absl::string_view name() const override; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 30d33e0d3531bb5e931ebfa0b60c91523dd0cb44..f70f6ddfec69c0113a1afe2073a2392098f49456 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 76e32174f3ee7d319df6f1f465e19d265d5330f2..147f3ae7b6d4ed0d4dadfb136e1e0f0bf3ae90c6 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -26,7 +26,7 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormExpander : public HloPassInterface { +class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index 5dcd31b83d24f836d31f44181f39cb8371ca1033..cb3d12f0bfd0e502136ce39660e091dc1c3879be 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -31,7 +31,7 @@ namespace xla { // optimization pipeline followed by a DCE pass. If other passes are needed // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the // changed made by this pass. -class BFloat16ConversionFolding : public HloPassInterface { +class BFloat16ConversionFolding : public HloModulePass { public: explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 30b6346312790f0a199f96f1956ba9ce3e617f72..f48e925823cf02bf4351b9bc7741123f5b1cd06f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -25,7 +25,7 @@ namespace xla { // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not // support BF16 input/output or mixed precision, according to the passed-in // backend-specific BF16 support rules. -class BFloat16Normalization : public HloPassInterface { +class BFloat16Normalization : public HloModulePass { public: explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} @@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface { // use mixed precision; it removes mixed precision even if the backend supports // it. This pass is used to make the HLO module valid for other HLO passes which // do not support mixed precision. -class BFloat16MixedPrecisionRemoval : public HloPassInterface { +class BFloat16MixedPrecisionRemoval : public HloModulePass { public: BFloat16MixedPrecisionRemoval() {} diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cef0eba14e9dd463d6c32b047211bf25a84478f6..2411fdcb2089c234d2b4fa3db9498d7f6b3a40ad 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -284,7 +284,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction::CreateParameter(1, s32_shape, "value")); HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 58f78f8e24d0bc00a63e3583828cf8e01ae4531a..002be9c97098ef1f73446c458dae24bbc826a626 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( }; auto root = fusion->fused_instructions_computation()->root_instruction(); - tensorflow::gtl::FlatSet changed_root_buffers; + absl::flat_hash_set changed_root_buffers; auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { @@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations) { + absl::flat_hash_set* visited_computations) { bool parameter_changed = false; auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse @@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( // another input parameter. A fixed point will be reached because the // parameters can only be changed from BF16 to F32, not the other way // around. - tensorflow::gtl::FlatSet visited_in_while; + absl::flat_hash_set visited_in_while; while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), &visited_in_while) || ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), @@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { const auto& computations_topological_order = module->MakeComputationPostOrder(); - tensorflow::gtl::FlatSet resolved; + absl::flat_hash_set resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { if (ContainsKey(resolved, *comp_it)) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1ee64971ab53e1775294afde1c779369a838008a..5fcaa15c8356107af02e9099874a293d8350c51a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,7 +60,7 @@ namespace xla { // BFloat16ConversionFolding. If other passes are needed after this pass, run // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this // pass. -class BFloat16Propagation : public HloPassInterface { +class BFloat16Propagation : public HloModulePass { public: explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); @@ -81,7 +83,7 @@ class BFloat16Propagation : public HloPassInterface { // The set of instructions to consider using bfloat16, computed in the forward // pass. - tensorflow::gtl::FlatSet consider_using_bfloat16_; + absl::flat_hash_set consider_using_bfloat16_; // *************************** // Functions called and state produced by the backward pass (from root to @@ -110,12 +112,12 @@ class BFloat16Propagation : public HloPassInterface { // The set of HloInstructions that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet + absl::flat_hash_set instructions_visited_in_backward_pass_; // The set of HloComputations that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet + absl::flat_hash_set computations_visited_in_backward_pass_; // *************************** @@ -131,7 +133,7 @@ class BFloat16Propagation : public HloPassInterface { // point is reached. bool ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet* visited_computations); + absl::flat_hash_set* visited_computations); // Makes the parameters of called computations match how they are called by // the given HLO. @@ -182,11 +184,11 @@ class BFloat16Propagation : public HloPassInterface { PrimitiveType target_type); // The set of F32 HLO values that must be kept in F32. - tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + absl::flat_hash_set values_that_must_be_kept_as_f32_; // Mapping from each HloComputation to the number of callers to it in the // module. Populated at the beginning of this pass. - tensorflow::gtl::FlatMap caller_counts_; + absl::flat_hash_map caller_counts_; // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which // are subject to further adjustment, then finally applied to the HLOs. This @@ -195,8 +197,7 @@ class BFloat16Propagation : public HloPassInterface { // // For each HloInstruction, changes_to_bf16_ stores the affected buffers in // the output as a map from in-place pointers to subshapes to shape indices. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 23645346e6f491beb5171cc839c013ce5f83d789..5b48f10505e78c035608d4c575501e4623218987 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: + case HloOpcode::kCollectivePermute: case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 65fa951afe3e60652413206913640af38f5bb824..2c2d1626c2c0d5d4b13e401dad9fd6c51514fc13 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -41,10 +43,10 @@ limitations under the License. namespace xla { namespace { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::StrAppend; using absl::StrAppendFormat; -using ::tensorflow::gtl::FlatMap; -using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::HumanReadableNumBytes; template @@ -128,8 +130,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; + flat_hash_set thread_local_set; + flat_hash_set global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -444,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { using SliceSet = - FlatSet; + flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -519,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // BufferAllocation. void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - FlatMap + flat_hash_map combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations @@ -582,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() { } // Update allocation indices to their new positions. - allocation_index_for_buffer_.clear_no_resize(); + allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(), + allocation_index_for_buffer_.end()); for (size_t index = 0; index < allocations_.size(); ++index) { BufferAllocation* allocation = &allocations_[index]; allocation->set_index(index); @@ -812,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet& colocated_buffers, - const FlatSet& colocated_allocations, - FlatMap>* + const flat_hash_set& colocated_buffers, + const flat_hash_set& colocated_allocations, + flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of @@ -833,7 +837,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Generate a post order sort of instructions for sorting of the // LogicalBuffers. - FlatMap post_order_position; + flat_hash_map post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -850,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation( // buffers_to_assign_sequentially map, even if we end up with an empty set // of buffers. This ensures we can correctly determine whether to run // whole-module heap simulation. - buffers_to_assign_sequentially->emplace(computation, - FlatSet()); + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set()); } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers @@ -1043,12 +1047,12 @@ Status BufferAssigner::AssignBuffersForComputation( return Status::OK(); } -FlatMap, - LogicalBuffer::Color::Hasher> +flat_hash_map, + LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const FlatSet& buffers) { - FlatMap, - LogicalBuffer::Color::Hasher> + const flat_hash_set& buffers) { + flat_hash_map, + LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { color_map[buffer->color()].insert(buffer); @@ -1057,23 +1061,38 @@ BufferAssigner::SplitBuffersByColor( } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const FlatMap>& + const flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + + // Returns a heap algorithm that chooses the best result from several + // algorithms. + auto get_heap_algorithm = [&](int64 alignment) { + auto algorithms = + absl::make_unique>>(); + algorithms->push_back(absl::make_unique( + absl::make_unique(alignment))); + algorithms->push_back( + absl::make_unique(alignment)); + return absl::make_unique(std::move(algorithms)); + }; + if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; HloSchedule schedule(&assignment->module()); - FlatSet all_buffers_to_assign; + flat_hash_set all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1093,8 +1112,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique( - absl::make_unique(alignment)), + HeapSimulator::Run(get_heap_algorithm(alignment), assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1108,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet& buffers_to_assign = pair.second; + const flat_hash_set& buffers_to_assign = + pair.second; const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1123,12 +1142,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run( - absl::make_unique( - absl::make_unique(alignment)), - *computation, HloInstructionSequence(*instruction_sequence), - assignment->points_to_analysis(), assignment->buffer_size_, - options)); + HeapSimulator::Run(get_heap_algorithm(alignment), *computation, + HloInstructionSequence(*instruction_sequence), + assignment->points_to_analysis(), + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1145,9 +1162,8 @@ std::vector ComputePeakMemoryLogicalBuffers( const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical // buffers in this allocation. - tensorflow::gtl::FlatMap - id_to_buffer; - tensorflow::gtl::FlatMap buffer_sizes; + absl::flat_hash_map id_to_buffer; + absl::flat_hash_map buffer_sizes; for (const auto& pair : allocation.assigned_buffers()) { const LogicalBuffer* buffer = pair.first; const BufferAllocation::OffsetSize& offset_size = pair.second; @@ -1186,7 +1202,7 @@ std::vector ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. - tensorflow::gtl::FlatSet live_buffers; + absl::flat_hash_set live_buffers; live_size = 0; for (const auto& event : heap_trace.events()) { const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); @@ -1576,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - FlatSet* colocated_buffers, - FlatSet* colocated_allocations) { + flat_hash_set* colocated_buffers, + flat_hash_set* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry @@ -1650,8 +1666,8 @@ StatusOr> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - FlatSet colocated_buffers; - FlatSet colocated_allocations; + flat_hash_set colocated_buffers; + flat_hash_set colocated_allocations; std::vector colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); @@ -1669,7 +1685,7 @@ StatusOr> BufferAssigner::CreateAssignment( // First assign buffers for global computatations. Temporary buffers for // sequential computations are collected in 'buffers_to_assign_sequentially'. - FlatMap> + flat_hash_map> buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 24ba7c16f548c10f58f41d2b88488939ca2d8e4d..899cd36e1f98c9e7b8ba7e42c06ced5c3e8afcc8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -33,8 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -148,7 +148,7 @@ class BufferAllocation { // Access to the logical buffers assigned to this allocation, and their // associated logical offsets and sizes. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& assigned_buffers() const { return assigned_buffers_; } @@ -323,7 +323,7 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. - tensorflow::gtl::FlatMap assigned_buffers_; + absl::flat_hash_map assigned_buffers_; int64 fragmentation_bytes_ = 0; std::vector heap_traces_; @@ -500,7 +500,7 @@ class BufferAssignment { int64 temp_allocation_total_size_ = 0; // Maps Buffers to the index of the BufferAllocation which holds the buffer. - tensorflow::gtl::FlatMap + absl::flat_hash_map allocation_index_for_buffer_; const HloModule* module_; @@ -554,11 +554,10 @@ class BufferAssigner { // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, - tensorflow::gtl::FlatMap>* + const absl::flat_hash_set& colocated_buffers, + const absl::flat_hash_set& colocated_allocations, + absl::flat_hash_map>* buffers_to_assign_sequentially, BufferAssignment* assignment); @@ -568,9 +567,8 @@ class BufferAssigner { // 'run_whole_module_heap_simulation' is true, the heap simulation will be run // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( - const tensorflow::gtl::FlatMap< - const HloComputation*, - tensorflow::gtl::FlatSet>& + const absl::flat_hash_map>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -590,7 +588,7 @@ class BufferAssigner { // alias. Explicitly handling these colocated buffers is necessary because // points-to analysis is computation level scope and does not recognize // aliasing across computations (b/32491382). - using ColocatedBufferSet = tensorflow::gtl::FlatSet; + using ColocatedBufferSet = absl::flat_hash_set; // Returns a vector of ColocatedBufferSet objects, where each // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' @@ -605,8 +603,8 @@ class BufferAssigner { void AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations); + absl::flat_hash_set* colocated_buffers, + absl::flat_hash_set* colocated_allocations); // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining // the invariant that all sets in 'colocated_buffer_sets' are disjoint. @@ -624,11 +622,10 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. - tensorflow::gtl::FlatMap, - LogicalBuffer::Color::Hasher> - SplitBuffersByColor( - const tensorflow::gtl::FlatSet& buffers); + absl::flat_hash_map, + LogicalBuffer::Color::Hasher> + SplitBuffersByColor(const absl::flat_hash_set& buffers); // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index cdd3cf4032ef6916086e1c2d148b575192503000..f939a426ead7c34092fc5234ef779ee857347a26 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -27,8 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -102,7 +101,7 @@ class BufferLiveness { // Set of LogicalBuffers which are aliased in the output of other // instructions. For example, a LogicalBuffer which is inserted into a tuple // is considered to be aliased and will be in this set. - tensorflow::gtl::FlatSet aliased_buffers_; + absl::flat_hash_set aliased_buffers_; // LogicalBuffers that may be live out of the entry computation. PointsToSet::BufferSet maybe_live_out_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h index 305914fca828f110bf54239bddb1590172562b16..cc46af5eeec623e19637cd6245915b3a3124a2cd 100644 --- a/tensorflow/compiler/xla/service/buffer_value_containers.h +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet( return output; } -using BufferValueFlatSet = tensorflow::gtl::FlatSet; +using BufferValueFlatSet = absl::flat_hash_set; template BufferValueFlatSet ToBufferValueFlatSet( const LogicalBufferContainerT& logical_buffer_container) { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 23b2a327096dfdb3c756a4acc5476ec01dcac1b3..bdd5069632e84fe6c67ca129f726432479ac1b35 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { bool CallGraph::DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { if (a == b || ContainsKey(*visited, b)) { // The call graph is guaranteed to be acyclic so any previously visited node // we encounter was already determined to be dominated. @@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper( bool CallGraph::Dominates(const HloComputation* a, const HloComputation* b) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; return DominatesHelper(a, b, &visited); } @@ -277,7 +278,7 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const { + absl::flat_hash_set* visited) const { auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. @@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 3af2ab5edfd9faf4ac5193df4b823c21b55b2f7f..cb56f4789d06ac33acdaadc8b619b9e37f683d58 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -20,11 +20,11 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -145,19 +145,19 @@ class CallGraphNode { // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callees_; - tensorflow::gtl::FlatSet callee_set_; + absl::flat_hash_set callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector callers_; - tensorflow::gtl::FlatSet caller_set_; + absl::flat_hash_set caller_set_; // The call sites in this computation std::vector callsites_; // The map from instruction to index in callsites_ for looking up the callsite // (if any) associated with a particular instruction in this computation. - tensorflow::gtl::FlatMap callsite_instructions_; + absl::flat_hash_map callsite_instructions_; // The call sites in other computations which call this computation. std::vector caller_callsites_; @@ -250,14 +250,14 @@ class CallGraph { // 'visited'. Status VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet* visited) const; + absl::flat_hash_set* visited) const; // The HLO module represented by this call graph. const HloModule* module_ = nullptr; @@ -267,7 +267,7 @@ class CallGraph { // Map from HLO computation to the index of the corresponding call graph node // in nodes_. - tensorflow::gtl::FlatMap node_indices_; + absl::flat_hash_map node_indices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index c5cd88b9ea2a9c308786d4d7476316b1e592d40a..08c4aff4f7fc7fc332fc7f34ece019eb57d71f3a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -25,7 +25,7 @@ namespace xla { // For every kCall operation in the main computation, we inline the body of the // called function, and proceed recursively. -class CallInliner : public HloPassInterface { +class CallInliner : public HloModulePass { public: using InlinedInstructionMap = std::unordered_map; diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index e5a6c28478a7ebf87878c3937069f15cafe12615..96bd2616f5607de888a096f8392ceb68490276e3 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 3de50cbd7ff752e8722a103b68f75144c6c889cd..2223ad67534dc31fc2c56ce68bdc87e881f20f32 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that removes kConditional with a constant predicate, replacing them // with their true or false computation as appropriate. -class ConditionalSimplifier : public HloPassInterface { +class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index 498894737fa37a6d8cca6ead2a86c72eb84ababd..ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -25,7 +25,7 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloPassInterface { +class ConvolutionFeatureGroupConverter : public HloModulePass { public: ConvolutionFeatureGroupConverter() {} diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b65dfef9c9575b683b2656af2ccc151d87db2cd7..f35324aa35370b592871749cba9fc2f66bea9219 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -432,7 +432,7 @@ class CopyRemover { // Construct a list for each HLO buffer in the alias analysis. Maintain a // map from HloValue to the respective list element representing that // value. The map is used to construct the copy info map below. - tensorflow::gtl::FlatMap value_to_node; + absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { // Verify values contained in the buffer are strictly ordered. This // should always be the case after adding copies to eliminate @@ -480,7 +480,7 @@ class CopyRemover { // respective ValueNode representing that value. void AddValueList( absl::Span values, - tensorflow::gtl::FlatMap* value_to_node) { + absl::flat_hash_map* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; for (const HloValue* value : values) { @@ -516,8 +516,7 @@ class CopyRemover { // respective ValueNode. void CreateCopyMap( const HloModule& module, - const tensorflow::gtl::FlatMap& - value_to_node) { + const absl::flat_hash_map& value_to_node) { for (HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with @@ -905,7 +904,7 @@ class CopyRemover { // The heads of all the value lists. Each value list represents the HLO // values contained in a particular HLO buffer. The values in the list are // in dependency order. - tensorflow::gtl::FlatSet value_lists_; + absl::flat_hash_set value_lists_; // Copy removal requires fast access to the value list elements // corresponding to the source and destination values of the kCopy @@ -916,7 +915,7 @@ class CopyRemover { ValueNode* src = nullptr; ValueNode* dest = nullptr; }; - tensorflow::gtl::FlatMap copy_map_; + absl::flat_hash_map copy_map_; }; HloModule* module_; @@ -1010,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. - tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector buffers_at_index = diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index d308f6bc84670b78b9cab476f2893bce267df2cf..c097089e30d59936a32f69c49123c398f1611ea3 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -43,7 +43,7 @@ namespace xla { // (3) The buffer set of the root instruction of the entry computation must be // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and // InstructionAliasSet::IsDistinct return true. -class CopyInsertion : public HloPassInterface { +class CopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 8cc522a59e9805ec86e9e69c8d6e5fa1a3ab682d..58abb330a6e31e9b7a8081cd7964cf89a5b64a09 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -93,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -126,7 +128,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", @@ -180,6 +181,7 @@ cc_library( ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", + ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -288,6 +290,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -307,6 +311,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@llvm//:analysis", "@llvm//:target", ], @@ -461,12 +466,16 @@ cc_library( ], copts = runtime_copts(), deps = [ + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -623,6 +632,18 @@ cc_library( ], ) +cc_library( + name = "runtime_key_value_sort", + srcs = ["runtime_key_value_sort.cc"], + hdrs = ["runtime_key_value_sort.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], @@ -745,6 +766,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 59437e88af27528654a0af86baf69ec7a1e91d60..becee3f81fc34c73040d53e4a261bc3a656cd78c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -31,7 +31,7 @@ namespace cpu { // called canonical convolutions). This pass expands non-canonical convolutions // into reshapes and canonical convolutions, so that these non-canonical // convolutions can run faster. -class ConvCanonicalization : public HloPassInterface { +class ConvCanonicalization : public HloModulePass { public: explicit ConvCanonicalization( const TargetMachineFeatures* target_machine_features) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 18fc144efe0023c0893adfcb16eda3341c0938d3..68c715a086af2a53acd510d51479b29e2eeac632 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,8 +86,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" @@ -249,9 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding - // where we will take this pass in future. - // pipeline.AddPass(); + pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. @@ -308,7 +306,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), target_machine_features); + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -328,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - pass.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + // TODO(b/117156505): When the bug is fixed, the CPU backend should not + // produce layout changing elementwise operations. We will then pass + // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to + // enable stricter verification. + pass.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index d49f7d7cc2d9b1d00847feda62fa62dd740820d8..076235f8874b5de57075fb690dd1b9111b6838a6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -30,7 +30,7 @@ namespace xla { // // TODO(b/62548313): Remove this when buffer assignment is smarter // (module-scoped). -class CpuCopyInsertion : public HloPassInterface { +class CpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 6af724b2a5d71b9c30f3485ffb7e51d1d201cb6b..a39a9d4724655370454d60fbb7b474f223bd8a85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the CPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class CpuHloSupportChecker : public HloPassInterface { +class CpuHloSupportChecker : public HloModulePass { public: CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index bfecbd6e017893e4f6d3dcbc01d46c899e6060fa..c291bf2d1ba2eaff4192051840768c037bece86f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" @@ -38,7 +39,7 @@ using absl::nullopt; using absl::optional; using ShouldMakeOperandColMajorCache = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; } // namespace static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 3c4fe68b830d9602f009b318d4e51e9a04a27e09..f4da35dd373f24d81323d198582048e2e6d36268 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, const TargetMachineFeatures* target_machine_features) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 4668f3872dad598edf4c7680e1b601622104ab3e..97659b88a7974d7caf91ab0d4741f3635e4dae4a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase { [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -321,8 +322,9 @@ static StatusOr RunDotOutputFusion( [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8a44c384bb0fe6f132c352ca8bd78baa23d093d4..a9febe891b5e9d1eb9e6b297952b50d1d26a3396 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,19 +17,29 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace cpu { namespace runtime { -XfeedManager* GetXfeedManager() { - static XfeedManager* manager = new XfeedManager; - return manager; +XfeedManager* GetXfeedManager(int device_ordinal) { + static auto* managers = new absl::flat_hash_map(); + static absl::Mutex* mutex = new absl::Mutex(); + + absl::MutexLock lock(mutex); + auto it = managers->find(device_ordinal); + if (it == managers->end()) { + it = managers->emplace(device_ordinal, new XfeedManager()).first; + } + return it->second; } extern const char* const kEigenMatMulF16SymbolName = @@ -74,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; +extern const char* const kKeyValueSortPREDSymbolName = + "__xla_cpu_runtime_KeyValueSortPRED"; +extern const char* const kKeyValueSortS8SymbolName = + "__xla_cpu_runtime_KeyValueSortS8"; +extern const char* const kKeyValueSortU8SymbolName = + "__xla_cpu_runtime_KeyValueSortU8"; +extern const char* const kKeyValueSortS16SymbolName = + "__xla_cpu_runtime_KeyValueSortS16"; +extern const char* const kKeyValueSortU16SymbolName = + "__xla_cpu_runtime_KeyValueSortU16"; +extern const char* const kKeyValueSortF16SymbolName = + "__xla_cpu_runtime_KeyValueSortF16"; +extern const char* const kKeyValueSortS32SymbolName = + "__xla_cpu_runtime_KeyValueSortS32"; +extern const char* const kKeyValueSortU32SymbolName = + "__xla_cpu_runtime_KeyValueSortU32"; +extern const char* const kKeyValueSortF32SymbolName = + "__xla_cpu_runtime_KeyValueSortF32"; +extern const char* const kKeyValueSortS64SymbolName = + "__xla_cpu_runtime_KeyValueSortS64"; +extern const char* const kKeyValueSortU64SymbolName = + "__xla_cpu_runtime_KeyValueSortU64"; +extern const char* const kKeyValueSortF64SymbolName = + "__xla_cpu_runtime_KeyValueSortF64"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime @@ -94,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireInfeedBufferForDequeue: " - << ShapeString(shape, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireInfeedBufferForDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireInfeedBufferForDequeue: " + << ShapeString(shape, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->infeed()->BlockingDequeueBuffer(); @@ -114,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseInfeedBufferAfterDeque: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, @@ -130,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireOutfeedBufferForPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireOutfeedBufferForPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->outfeed()->BlockingDequeueBuffer(); @@ -150,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index aa0e96712302e806a389c6ad05a2c1b6634ef901..b2e760a224ad8eaa61dae57b0f9cece04a7e54ae 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,6 +26,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" @@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; +extern const char* const kKeyValueSortPREDSymbolName; +extern const char* const kKeyValueSortS8SymbolName; +extern const char* const kKeyValueSortU8SymbolName; +extern const char* const kKeyValueSortS16SymbolName; +extern const char* const kKeyValueSortU16SymbolName; +extern const char* const kKeyValueSortF16SymbolName; +extern const char* const kKeyValueSortS32SymbolName; +extern const char* const kKeyValueSortU32SymbolName; +extern const char* const kKeyValueSortF32SymbolName; +extern const char* const kKeyValueSortS64SymbolName; +extern const char* const kKeyValueSortU64SymbolName; +extern const char* const kKeyValueSortF64SymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. extern const char* const kXlaCpuRuntimeSymbolNamePrefix; -// Returns the infeed manager used by the CPU runtime. -XfeedManager* GetXfeedManager(); +// Returns the infeed manager used by the CPU runtime for the CPU device +// `device_ordinal`. Note the device ordinal does not name a CPU +XfeedManager* GetXfeedManager(int device_ordinal); } // namespace runtime } // namespace cpu @@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager(); extern "C" { +// Some things common to all of the runtime entry points below: +// +// * The shape pointer and shape_length reflect values that can be deserialized +// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass +// reified type information from the generated program to the runtime, which +// helps check the type safety and contract for the emitted-code/runtime +// communication. +// +// * run_options is used to look up the device ordinal for the stream executor +// we're executing under. If it is null the device ordinal is assumed to be +// 0 (this behavior helps in writing tests). + // Note: in the runtime entry points below, the shape pointer and shape_length // reflect values that can be deserialized via // llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified @@ -89,7 +115,8 @@ extern "C" { // the length would be more exact, but the length check is chosen as a // tradeoff between error checking and speed/simplicity. extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length, const void* shape, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length); // Relinquishes the next infeed buffer that was returned by // __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call @@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // implemented we will add support for multiple outstanding buffers // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); // Blocks until the next outfeed buffer is available to be populated, then // returns it. extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length); // Relinquishes the outfeed buffer after it has been populated. // buffer_ptr must have been previously returned by @@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( // acquired, i.e., there may only be one outstanding outfeed buffer in // use by the runtime. extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); } // extern "C" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 5519a43b2f6bc3a7df9a58823e43fae42f7f94df..1cc2844470376ceb61601f6d1361def84eac5b45 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { @@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed( buffers.push_back(buffer); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers); cleanup.release(); @@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer}); return Status::OK(); @@ -265,7 +268,8 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( buffer_pointers.push_back(b.get()); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers); VLOG(2) << "Waiting for buffer to be notified as populated."; std::vector outfed_shapes; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index df8c2a636bbda52e3a8df00015ce3f27e6ba1aea..b2abdb39a598871a7cc44760e464f48b9a200874 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -67,8 +69,6 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value * shape_ptr, llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); - // The signature of the acquire infeed buffer function is: - // - // (void*)(int32 length); llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( - i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, + i8_ptr_type, + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* acquire_func; @@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, } acquire_func->setCallingConv(llvm::CallingConv::C); - // The signature of the release infeed buffer function is: - // - // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type, + /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = - Call(acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = Call( + acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. @@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, /*SrcAlign=*/1, length_32); } - Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr, - b_.getInt32(shape_length)}); + Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + acquired_pointer, shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } @@ -495,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { } Status IrEmitter::HandleSort(HloInstruction* sort) { - // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not implemented on CPU."); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto keys_destination_address = + EmitBufferPointer(keys_destination, keys->shape()); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + llvm::Value* values_destination_address = nullptr; + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. + if (keys_destination != GetAllocationSlice(*keys)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); + auto source_buffer = GetEmittedValueFor(keys); + int64 keys_size = ByteSizeOf(keys->shape()); + MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, keys_size); + } + if (values != nullptr) { + values_destination_address = + EmitBufferPointer(values_destination, values->shape()); + if (values_destination != GetAllocationSlice(*values)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); + auto source_buffer = GetEmittedValueFor(values); + int64 values_size = ByteSizeOf(values->shape()); + MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, values_size); + } + } + + // Normalize the shape and the dimension to sort. + Shape normalized_keys_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + keys->shape()); + int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( + keys->shape().layout())[sort->dimensions(0)]; + + int64 sort_dimension_elements = + normalized_keys_shape.dimensions(physical_dimension_to_sort); + int64 higher_dimensions = 1; + for (int64 i = 0; i < physical_dimension_to_sort; ++i) { + higher_dimensions *= normalized_keys_shape.dimensions(i); + } + int64 lower_dimensions = 1; + for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + i > physical_dimension_to_sort; --i) { + lower_dimensions *= normalized_keys_shape.dimensions(i); + } + + PrimitiveType keys_type = keys->shape().element_type(); + const char* fn_name = nullptr; + llvm::Type* keys_native_type = nullptr; + switch (keys_type) { + case PRED: + fn_name = runtime::kKeyValueSortPREDSymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S8: + fn_name = runtime::kKeyValueSortS8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case U8: + fn_name = runtime::kKeyValueSortU8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S16: + fn_name = runtime::kKeyValueSortS16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case U16: + fn_name = runtime::kKeyValueSortU16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case F16: + fn_name = runtime::kKeyValueSortF16SymbolName; + keys_native_type = b_.getHalfTy()->getPointerTo(); + break; + case S32: + fn_name = runtime::kKeyValueSortS32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case U32: + fn_name = runtime::kKeyValueSortU32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case F32: + fn_name = runtime::kKeyValueSortF32SymbolName; + keys_native_type = b_.getFloatTy()->getPointerTo(); + break; + case S64: + fn_name = runtime::kKeyValueSortS64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case U64: + fn_name = runtime::kKeyValueSortU64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case F64: + fn_name = runtime::kKeyValueSortF64SymbolName; + keys_native_type = b_.getDoubleTy()->getPointerTo(); + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } + + llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( + b_.getVoidTy(), + {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + b_.getInt8PtrTy(), b_.getInt32Ty()}, + /*isVarArg=*/false); + auto* key_value_sort_func = llvm::cast( + module_->getOrInsertFunction(fn_name, key_value_sort_type)); + key_value_sort_func->setCallingConv(llvm::CallingConv::C); + key_value_sort_func->setDoesNotThrow(); + key_value_sort_func->setOnlyAccessesArgMemory(); + Call(key_value_sort_func, + {PointerCast(keys_destination_address, keys_native_type), + b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), + values != nullptr + ? PointerCast(values_destination_address, b_.getInt8PtrTy()) + : llvm::Constant::getNullValue(b_.getInt8PtrTy()), + b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType( + values->shape().element_type()) + : 0)}); + + if (values != nullptr) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), + {keys_destination_address, values_destination_address}, + &b_, module_); + } + return Status::OK(); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -547,8 +688,25 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to @@ -587,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -1257,10 +1409,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains // [0->0, 3->1]. - gtl::FlatMap unreduced_dim_map; + absl::flat_hash_map unreduced_dim_map; - gtl::FlatSet reduced_dims(reduce.dimensions().begin(), - reduce.dimensions().end()); + absl::flat_hash_set reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); const Shape& operand_shape = reduce.operand(0)->shape(); const Shape& result_shape = reduce.shape(); @@ -1836,7 +1988,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - gtl::FlatSet inner_dims; + absl::flat_hash_set inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3df99464ba1103488b9fe054593740ada108d3da..586f27b104ed706a3b128903c6a90abbf3667e59 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/Triple.h" @@ -47,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie(); + } + private: // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); @@ -421,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Maps the buffer allocation slices for the parameters to the computation // being compiled to their parameter numbers. Only relevant for thread local // computations. - tensorflow::gtl::FlatMap + absl::flat_hash_map computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. @@ -561,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, } }; - tensorflow::gtl::FlatMap + absl::flat_hash_map emitted_literals_; - tensorflow::gtl::FlatMap + absl::flat_hash_map constant_buffer_to_global_; std::vector thread_local_computations_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index b4c0c09ec06bac9b5e228428c072948afdd4a547..ede7f433ca6b2cc5629115f800348be9dfb2b93b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + opcode == HloOpcode::kSort || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index a99cd99c14abb66fc426c43656520e01f34a1700..3822d5300e30704f68b2cf0c7f0b77d595c17a25 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -60,7 +60,7 @@ class ParallelTaskAssignment { // own embedded computation, which is compiled as a parallel compute function, // and which is invoked from a kCall instruction that is lowered in codegen to // a runtime parallel fork/join call. -class ParallelTaskAssigner : public HloPassInterface { +class ParallelTaskAssigner : public HloModulePass { public: // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0e7deb98e579c090c8fae320a3ba8a3ce8dbe5f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -0,0 +1,236 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" + +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace { +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +template +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements); +} + +// For floating point numbers, we want a total order comparator. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. Also we want to have a stable sort, so if the keys are the +// same, we compare the index values. +template +bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(lhs); + bool rhs_nan = std::isnan(rhs); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; + } + if (lhs != rhs) { + return lhs < rhs; + } + return lhs_index < rhs_index; +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair* row_to_sort, + int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), lhs.second, + Eigen::half_impl::half_to_float(rhs.first), rhs.second); + }); +} + +template +void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + // High-level idea of the iteration/sorting logic: + // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the + // dimension to sort, c is the product of the more minor dimensions (set to 1 + // if b is the most minor dimension), and a is the product of the more major + // dimensions (set to 1 if b is the most major dimension). There are a * c + // many rows that we need to sort. We iterate through these, calculate a + // 'base_offset' value which points to the first element in that row, and add + // i * c for accessing the 'i'-th element in that row. + + int64 sort_dimension_elements = b; + int64 num_iteration_elements = a * c; + int64 sort_dimension_offset = c; + + std::unique_ptr[]> row_to_sort( + new std::pair[sort_dimension_elements]); + std::unique_ptr reordered_values( + new std::string[sort_dimension_elements]); + for (int64 index = 0; index < num_iteration_elements; ++index) { + // 'index' can be split into two values which index into the 'c' dimension + // and the 'a' dimension, respectively. 'index' % 'c' is the index into the + // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When + // calculating the base offset, we need to multiply the index into the 'a' + // dimension with 'b' * 'c'. + // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. + int64 base_offset = + index % sort_dimension_offset + + (index - index % sort_dimension_offset) * sort_dimension_elements; + // TODO(b/26783907): We could define a custom iterator class that references + // both arrays. Then we could avoid the intermediate copy. However this + // would become more complicated, and it is not clear if the benefit is high + // enough. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + row_to_sort[i] = + std::make_pair(keys[base_offset + i * sort_dimension_offset], i); + } + KeyValueSort(row_to_sort.get(), sort_dimension_elements); + for (int64 i = 0; i < sort_dimension_elements; ++i) { + keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + } + if (values == nullptr) { + continue; + } + + // Reorder the values according to the order defined by the keys. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + + reordered_values[i] = std::string(values + memory_index, + values_primitive_type_size_in_bytes); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + memcpy(values + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes); + } + } +} +} // namespace + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( + int8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( + uint8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( + int16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( + uint16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( + int32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( + uint32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( + float* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( + int64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( + uint64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( + double* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h new file mode 100644 index 0000000000000000000000000000000000000000..28e35e82c18cbf078f8a1e7f5b818bf839d3d3df --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' +// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr. +// If 'values' is not nullptr, the elements in 'values' are reordered in such a +// way that if the element at index 'i' in 'keys' was moved to index 'j', the +// element at index 'i' in 'values' is also moved to index 'j' (which means that +// the same elements correspond to each other as before). +extern void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS8( + tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU8( + tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS16( + tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU16( + tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS32( + tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU32( + tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF32( + float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS64( + tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU64( + tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF64( + double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index bf98064647f4c29ba689902da4d737e1922391d3..9ec0c8f65705db335379649def746921e6b05bea 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" @@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index a0cd8ee2d2be10bcee9c2e216e24908d949e2d7b..5cdac203af2e7a1f8f3aebda965447ba75e9934e 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/core/platform/logging.h" namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 8b00ae9e47eeed26ffe80707b89593b267e8dbb8..a383b4a4a00f9b8d49a88e8349793a3a90d8da7b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#include "absl/container/flat_hash_map.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace cpu { @@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { // This is mutated from within `GetTargetTransformInfoFor` which is // semantically a getter (and thus `const`); and is therefore declared // mutable. Making this mutable is okay because it has cache semantics. - mutable tensorflow::gtl::FlatMap + mutable absl::flat_hash_map target_transform_info_cache_; llvm::TargetMachine* target_machine_; }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index c55206eee7ae3c6e4410c59aebf529de98fd2de8..4b129c95d46d8b5a119e5d23eef387daf7863cce 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -180,3 +180,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_key_value_sort_test", + srcs = ["cpu_key_value_sort_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3934c03a04c978009282b3cd0d39bacf9b12a356 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuKeyValueSortTest : public CpuCodegenTest {}; + +TEST_F(CpuKeyValueSortTest, SortR1) { + const string hlo_text = R"( +HloModule KeyValueSort + +ENTRY main { + a = f32[10] parameter(0) + + ROOT result = f32[10] sort(f32[10] a), dimensions={0} +} +)"; + + string filecheck_pattern = R"( +CHECK: call void @__xla_cpu_runtime_KeyValueSort +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 7af51db55af44ae1e437ea8e4de7427012cad82f..b35fd9dad877c319c3d0110c96a00aeefa78769e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) { CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 - CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} )"; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index 8fe65f488a2f0c4031926fa4c5f02dcf5473568d..cc38b81455b5a35cdcd07ac1dfb80cc7b101a7bc 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) { auto shape = ShapeUtil::MakeShape(U8, {length}); string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - length, bytes.data(), bytes.size()); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, - bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } // Performs the acquire/release sequence on the outfeed, as the generated CPU @@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) { void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - length, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - length, buffer, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } TEST_F(InfeedManagerTest, SingleThreadedSequential) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); xfeed->infeed()->EnqueueBuffersAtomically({b}); @@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); ProcessNextBuffer(a->length()); @@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TEST_F(InfeedManagerTest, MultiThreaded) { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); const int32 length = 64; @@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->outfeed()->EnqueueBuffersAtomically({b}); ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc index d124f74d19d83269be96ee34a6b4b2a8d00a978f..661539cccb4ef27a49a73f97a0a8b0d9dfc77061 100644 --- a/tensorflow/compiler/xla/service/defuser.cc +++ b/tensorflow/compiler/xla/service/defuser.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) { fusion_instruction->fused_instructions_computation(); // A map from fused instruction to its defused clone. - tensorflow::gtl::FlatMap + absl::flat_hash_map defused_instructions; // Initialize map to contain the fusion instruction parameters mapping // to the operands of the fusion instruction. diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index c326beb899f9a434d772c0fda032efc9113b6f42..aaa41fc4fe779cdf01a34e86855cac02552ee383 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -25,7 +25,7 @@ namespace xla { // A pass which replaces all fusion instructions with the equivalent un-fused // instructions. -class Defuser : public HloPassInterface { +class Defuser : public HloModulePass { public: Defuser() {} ~Defuser() override {} diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ba2a674d9af547ad574ae49e1e87f3afcaf6112a..b3549acfc291a54b2345b006310613c3a45a4b47 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -24,7 +24,7 @@ namespace xla { namespace { // Pass which strips control dependencies from all instructions in the module. -class ControlDepRemover : public HloPassInterface { +class ControlDepRemover : public HloModulePass { public: ControlDepRemover() = default; absl::string_view name() const override { return "control-dep-remover"; } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 7be70add2f7566376b3179740e411d6341badf7c..46dcc3a438cbdf3ff1b3c99fa15b35ee7a4e280e 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -30,7 +30,7 @@ namespace xla { // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, // and BFloat16MixedPrecisionRemoval. -class Despecializer : public HloPassInterface { +class Despecializer : public HloModulePass { public: Despecializer(); absl::string_view name() const override { return "despecializer"; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 5761573791d90e45c65b55124a4bae3c5b929ef1..68d01d75a2ed3d7eaadb03a46ba3bd20f43a9ffc 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index fc38e317001695921d20f9bbe5775e61a8eeaa45..40e7a3b4c25ff20674de0cce3fe2907fc43a5cb9 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -23,7 +23,7 @@ namespace xla { // DotDecomposer is a pass which decomposes batch Dot operations into a // sequence of smaller (R2) Dot operations. -class DotDecomposer : public HloPassInterface { +class DotDecomposer : public HloModulePass { public: // Decomposes batch Dot operations when 'decompose_batch_dot' is true. DotDecomposer(bool decompose_batch_dot = true) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4bb1e071d8da75d0219d0b5cc9a6d16f1750a191..515267edd7caf42e04ebe638b99006db8967ea30 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { - if (prim_type != F32) { - // TODO(b/34339814): Implement inverse erf for F64. + if (prim_type != F16 && prim_type != F32 && prim_type != F64) { return Unimplemented( "Inverse erf is only implemented for element " - "type F32."); + "types F16, F32 and F64."); } - auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(b_->getFloatTy(), f); + + // Upcast half to float. + if (prim_type == F16) { + x = b_->CreateFPExt(x, b_->getFloatTy()); + } + + auto get_float = [&](const double f) { + return llvm::ConstantFP::get(x->getType(), f); }; - auto multiply_add = [&](absl::Span coefficients, + auto multiply_add = [&](absl::Span coefficients, llvm::Value* w) { - llvm::Value* p = getFloat(coefficients.front()); + llvm::Value* p = get_float(coefficients.front()); coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = FAdd(FMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), get_float(coefficient)); } return p; }; // Approximation for inverse error function from // Giles, M., "Approximating the erfinv function". - // The approximation has the form: - // w = log((1-x)*(1+x)) + // The approximation has the form (float version): + // w = -log((1-x)*(1+x)) // if ( w < 5 ) { // w = w - 2.5 // p = sum_{i=1}^n lq[i]*w^i @@ -879,46 +884,124 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {b_->getFloatTy()}); + module_, llvm::Intrinsic::log, {x->getType()}); - llvm::Value* w = FNeg( - Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg(Call( + logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))})); llvm::Value* p_addr = - llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); + llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_); + + if (prim_type == F16 || prim_type == F32) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(2.5f)); + absl::Span lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + Store(p, p_addr); + } - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); - // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, b_); - { - llvm::Value* lw = FSub(w, getFloat(2.5f)); - absl::Span lq{ - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - llvm::Value* p = multiply_add(lq, lw); - Store(p, p_addr); - } + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f)); + absl::Span gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + Store(p, p_addr); + } - // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, b_); - { - llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - - llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); - absl::Span gq{ - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - llvm::Value* p = multiply_add(gq, gw); - Store(p, p_addr); - } + SetToFirstInsertPoint(if_data.after_block, b_); + } else { + DCHECK(prim_type == F64); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_); + + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(3.125)); + absl::Span c{ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + llvm::Value* p = multiply_add(c, lw); + Store(p, p_addr); + } - SetToFirstInsertPoint(if_data.after_block, b_); + SetToFirstInsertPoint(if_data.false_block, b_); + llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_); + SetToFirstInsertPoint(if_data_second.true_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25)); + absl::Span t1{ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635}; + llvm::Value* p = multiply_add(t1, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data_second.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0)); + absl::Span t2{ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221}; + llvm::Value* p = multiply_add(t2, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data.after_block, b_); + } llvm::Value* p = Load(p_addr); - return FMul(p, x); + x = FMul(p, x); + // Trunc back to half if needed. + if (prim_type == F16) { + x = b_->CreateFPTrunc(x, b_->getHalfTy()); + } + return x; } StatusOr ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index 3cccec9862e0f92df478006939552099868121b9..986970f8862472d1db7564254a9c1277750bb6eb 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -26,7 +26,7 @@ namespace xla { // Flattening associates each call site with a unique computation (for // sequential calling contexts) This simplifies buffer assignment and // points-to analysis (see b/36865746 for details). -class FlattenCallGraph : public HloPassInterface { +class FlattenCallGraph : public HloModulePass { public: absl::string_view name() const override { return "flatten-call-graph"; } diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..1208a7dda87d7b2a6ad7113e2604e8b9a0fa045b --- /dev/null +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 7bd9ea598417a931d2df507d472c6a60be05e0bc..2b39359aae9fc01f1a88a2594108b2772788e826 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -23,7 +23,7 @@ namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. -class GatherExpander : public HloPassInterface { +class GatherExpander : public HloModulePass { public: absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 64b96836280718f13ac5ee9f4a497ed54a273b19..62da43d68a71981ff871949d17aa30bacef0ce8c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -68,9 +68,7 @@ cc_library( # srcs = [ # "partition_assignment_test.cc", # ], -# tags = [ -# "requires-gpu-sm35", -# ], +# tags = tf_cuda_tests_tags(), # deps = [ # ":partition_assignment", # "//tensorflow/core:stream_executor_no_cuda", @@ -93,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -359,6 +358,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -373,7 +373,6 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":backend_configs", - ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -405,6 +404,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) @@ -414,6 +414,8 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":backend_configs", + ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -422,8 +424,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -432,6 +436,7 @@ cc_library( srcs = ["cudnn_convolution_rewriter.cc"], hdrs = ["cudnn_convolution_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -472,6 +477,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -504,6 +510,7 @@ cc_library( "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -537,6 +544,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -583,6 +591,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", @@ -596,14 +605,11 @@ cc_library( hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) @@ -656,6 +662,7 @@ cc_library( deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", + ":cudnn_fused_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -699,7 +706,6 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", - "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -713,6 +719,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -774,7 +781,6 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ -783,6 +789,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -874,16 +881,6 @@ cc_library( ], ) -cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = ["gpu_options.h"], - deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], @@ -967,3 +964,19 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "cudnn_fused_convolution_rewriter", + srcs = ["cudnn_fused_convolution_rewriter.cc"], + hdrs = ["cudnn_fused_convolution_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 640c6392b8b820c708b853c2a3cea4d4116e85a8..78e14d860e31ace2fcb3f51fb8e0c40a0ea5f3dd 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -24,4 +24,18 @@ message CudnnConvBackendConfig { // true, cudnn may choose not to use tensor cores, e.g. because the GPU or // selected algorithm doesn't support it. bool tensor_ops_enabled = 2; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // CudnnConvParams for their meanings. + + // The requested activation (e.g. relu) after the convolution. It is with type + // stream_executor::dnn::ActivationMode. + int64 activation_mode = 3; + + // The scaling factor multiplied with the side input. If no side input buffer + // is provided, this field must be 0. + double side_input_scale = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 3a23ac1d634161628b2bd2589d0260022868ba36..4effea637d01bf23b54d341b77306b20b1b133c8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,37 +29,38 @@ limitations under the License. namespace xla { namespace gpu { -using se::dnn::AlgorithmDesc; +ConvolutionThunk::ConvolutionThunk( + const HloCustomCallInstruction* cudnn_call, + std::vector operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + operand_buffers_(std::move(operand_slices)), + result_buffer_(result_slice), + scratch_buffer_(scratch_slice), + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - CudnnConvParams params; + std::vector operand_se_buffers; + for (const auto& buffer : operand_buffers_) { + operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + se::DeviceMemoryBase result_buffer = + buffer_allocations.GetDeviceAddress(result_buffer_); - params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); - params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); - params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); - // Figure out which of output/input/filter is the result produced by - // this op, and write the result tuple. - void* result_ptr = [&] { - switch (params.kind) { - case CudnnConvKind::kForward: - return params.output_buf.opaque(); - case CudnnConvKind::kBackwardInput: - return params.input_buf.opaque(); - case CudnnConvKind::kBackwardFilter: - return params.filter_buf.opaque(); - } - }(); - void* ptrs[] = {result_ptr, scratch.opaque()}; + void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory tuple_addr( buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); stream->ThenMemcpyH2D(ptrs, &tuple_addr); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d7d1f91fba7239ed1670119f5df623d025c1d368..f53bc541983378819dba36489dd69c348f50af32 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // Note that "output" here doesn't refer to the output from running this - // thunk, but rather to the "output" of a hypothetical forward convolution - // that corresponds to this input+filter+output triple. That is, the result - // generated by this thunk is "output" for forward convs, "input" for - // backward-input convs, and "filter" for backward-filter convs. + // operand_slices should be in the same order as cudnn_call->operands(). ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, - BufferAllocation::Slice input_slice, - BufferAllocation::Slice filter_slice, - BufferAllocation::Slice output_slice, + std::vector operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice tuple_result_slice) - : Thunk(Kind::kConvolution, cudnn_call), - cudnn_call_(cudnn_call), - input_buffer_(std::move(input_slice)), - filter_buffer_(std::move(filter_slice)), - output_buffer_(std::move(output_slice)), - scratch_buffer_(std::move(scratch_slice)), - tuple_result_buffer_(std::move(tuple_result_slice)) {} + BufferAllocation::Slice tuple_result_slice); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk { private: const HloCustomCallInstruction* cudnn_call_; - BufferAllocation::Slice input_buffer_; - BufferAllocation::Slice filter_buffer_; - BufferAllocation::Slice output_buffer_; + std::vector operand_buffers_; + BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; BufferAllocation::Slice tuple_result_buffer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index 6e2e330edd4beabe0b395f05b80d57612d63f110..c3f58508ddd4451312325b0d440473515812dac9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -52,7 +52,7 @@ namespace gpu { // The GPU backend does not implement a lowering for the batchnorm HLOs -- it // expects them to be lowered to cudnn calls via this pass or to HLO soup via // BatchNormRewriter. -class CudnnBatchNormRewriter : public HloPassInterface { +class CudnnBatchNormRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index c607aea1a8c74057444467cecd7087f967bc7ee4..6d4a72038fb3f6ed657ee36eee82c3e5261d27b8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -76,54 +76,24 @@ StatusOr> ScratchAllocator::AllocateBytes( return se::DeviceMemory(buffer_addr); } -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output shapes. This works around b/68264959, an integer -// overflow in cuDNNv5 and cuDNNv6. -bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, - const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums, - se::StreamExecutor* stream_exec) { - // Skip this check for cudnn7 and newer. - auto version = stream_exec->AsDnn()->GetVersion(); - if (version.ok() && version.ValueOrDie().major_version() >= 7) { - return true; - } - - int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); - int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); - int64 in_cols = - dnums.input_spatial_dimensions_size() == 1 - ? 1 - : input_shape.dimensions(dnums.input_spatial_dimensions(1)); - int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); - - int64 total_size = CeilOfRatio(batch, int64{16}) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - - const int64 threshold = 1L << 31; - return total_size < threshold; -} - std::vector GetAlgorithms(CudnnConvKind kind, - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) { std::vector algorithms; + bool succ = false; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = + stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); break; case CudnnConvKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + case CudnnConvKind::kForwardActivation: + succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); break; } + DCHECK(succ); return algorithms; } @@ -175,21 +145,13 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -StatusOr> +StatusOr CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - const HloCustomCallInstruction* instr) { - CudnnConvParams params; - TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, ¶ms)); - - const Shape& input_shape = *params.input_shape; - const Shape& filter_shape = *params.filter_shape; - const Shape& output_shape = *params.output_shape; - - CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); - CHECK_EQ(input_shape.element_type(), output_shape.element_type()); + HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. - const bool cross_check_enabled = input_shape.element_type() == xla::F16; + const bool cross_check_enabled = + instr->shape().tuple_shapes(0).element_type() == xla::F16; // Don't run this function concurrently on the same GPU. // @@ -221,25 +183,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( allocator = &*se_allocator; } - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. - ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(params.input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(params.filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(params.output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. - const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); size_t left_over_bytes = buffer.size() % 4; CHECK_EQ(0, left_over_bytes % 2); @@ -257,51 +206,56 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( DeviceMemoryBase left_over( static_cast(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - }; - initialize_f16(params.input_buf); - initialize_f16(params.filter_buf); - initialize_f16(params.output_buf); - } else { - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the - // inputs might affect performance if e.g. the inputs contain denormals, and - // this is easy enough. - stream.ThenMemZero(¶ms.input_buf, params.input_buf.size()) - .ThenMemZero(¶ms.filter_buf, params.filter_buf.size()) - .ThenMemZero(¶ms.output_buf, params.output_buf.size()); - } - - DeviceMemoryBase* result_buf = [&] { - switch (params.kind) { - case CudnnConvKind::kBackwardFilter: - return ¶ms.filter_buf; - case CudnnConvKind::kBackwardInput: - return ¶ms.input_buf; - case CudnnConvKind::kForward: - return ¶ms.output_buf; + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); } - }(); + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + std::vector operand_buffers; + for (const auto* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + initialize_buffer(buffer); + operand_buffers.push_back(buffer); + } + TF_ASSIGN_OR_RETURN( + auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + initialize_buffer(result_buffer); - const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); optional comparator; // Use the first algorithm that's supported as reference. There isn't a // particular reason to use it, as any algorithm sufficies. It doesn't make // this algorithm considered correct, though. optional first_algorithm; - for (const AlgorithmDesc& alg : - GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - params.algorithm = AlgorithmConfig(alg); - bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, - &profile_result) + backend_config.set_algorithm(alg.algo_id()); + backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); + TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers), + result_buffer, &scratch_allocator, + &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { @@ -312,7 +266,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( .xla_gpu_crash_on_verification_failures(); if (comparator.has_value()) { StatusOr result = comparator->CompareEqual( - se::DeviceMemory(*result_buf)); + se::DeviceMemory(result_buffer)); if (!result.ok()) { LOG(ERROR) << "Unable to compare " << AlgorithmToString(*first_algorithm) << " against " @@ -330,7 +284,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( } } else if (cross_check_enabled) { auto comp = F16BufferComparator::Create( - se::DeviceMemory(*result_buf), compiler_, allocator, + se::DeviceMemory(result_buffer), compiler_, allocator, &stream); if (comp.ok()) { comparator.emplace(comp.ConsumeValueOrDie()); @@ -362,9 +316,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << AlgorithmToString(best_result.algorithm()) << ", takes " << best_result.elapsed_time_in_ms() << "ms, and uses " << best_result_bytes_used << "B of scratch memory."; - return std::make_tuple(best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used); + return AutotuneResult{best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used, + absl::Milliseconds(best_result.elapsed_time_in_ms())}; } return InternalError( @@ -377,40 +332,37 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - StatusOr> alg_scratch_and_tc = + StatusOr best_algo_or = PickBestAlgorithm(Cast(instr)); - - if (!alg_scratch_and_tc.ok()) { - LOG(ERROR) << alg_scratch_and_tc.status(); + if (!best_algo_or.ok()) { + LOG(ERROR) << best_algo_or.status(); return false; } - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = - alg_scratch_and_tc.ConsumeValueOrDie(); - - VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " - << NumBytesToString(scratch_bytes) + auto best_algo = std::move(best_algo_or).ValueOrDie(); + VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm + << " and " << NumBytesToString(best_algo.scratch_bytes) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); - Shape new_call_shape = - ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {scratch_bytes})}); + Shape new_call_shape = ShapeUtil::MakeTupleShape( + {instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); - CudnnConvBackendConfig backend_config; - backend_config.set_algorithm(algorithm); - backend_config.set_tensor_ops_enabled(tensor_ops_enabled); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + instr->backend_config()); + backend_config.set_algorithm(best_algo.algorithm); + backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction( - instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0), - instr->mutable_operand(1)})); + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + + VLOG(1) << "Replacing convolution " << instr->ToString() << " with " + << new_call->ToString(); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index f79b113f8fac0190adef9a8d68d1617710b1402c..136c32210a4afbd60cf9b13863befdba9b712a9d 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -30,7 +31,7 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloPassInterface { +class CudnnConvolutionAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, @@ -47,10 +48,16 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + struct AutotuneResult { + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + absl::Duration runtime; + }; + StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr> PickBestAlgorithm( - const HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 228379a2488a8564564e8b5e35a863553f4bbac2..437d25727e20afb1600daccd6f925f84db642855 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -35,6 +36,32 @@ namespace gpu { namespace { +HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); + return custom_call; +} + bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) { return std::make_tuple(true, new_window, dnums, rhs); } +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; +} + // Tries to rewrite a single convolution into a call to cudnn. StatusOr RunOnInstruction(HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); @@ -462,24 +495,24 @@ StatusOr RunOnInstruction(HloInstruction* conv) { std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - return CreateCudnnConvBackwardFilter( - conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums, conv->feature_group_count()); + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums, conv->feature_group_count()); } std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - return CreateCudnnConvBackwardInput(conv->shape(), - conv->mutable_operand(0), rhs, window, - dnums, conv->feature_group_count()); + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count()); } // If all else fails, try a forward convolution. if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers(), - conv->feature_group_count()); + return CreateCudnnConv( + kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers(), conv->feature_group_count()); } return nullptr; @@ -489,6 +522,12 @@ StatusOr RunOnInstruction(HloInstruction* conv) { return false; } + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << custom_call->ToString(); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index fbe7e9849458e9d52be15b3f5610479ab68ffa4c..8d7c6fdab510407428a115579a90e8cf85e9fad2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -24,7 +24,7 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. -class CudnnConvolutionRewriter : public HloPassInterface { +class CudnnConvolutionRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn-convolution-rewriter"; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 2a86ac265e4d6a6502162ac33b04b0ee362ce49e..a809c22b336ef83c6e3d8575997119e3a5288615 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -37,6 +39,42 @@ using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; using se::dnn::ProfileResult; +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. + struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional fusion; +}; + // A StreamExecutor ScratchAllocator that wraps a single XLA allocation, // returning it (in its entirety) the first time Allocate() is called. class ScratchBufAllocator : public se::ScratchAllocator { @@ -92,9 +130,9 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; - VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; - VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); + VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; @@ -186,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, switch (kind) { case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveWithAlgorithm( input_descriptor, input_buf, filter_descriptor, filter_buf, convolution_descriptor, output_descriptor, &output_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardDataWithAlgorithm( filter_descriptor, filter_buf, output_descriptor, output_buf, convolution_descriptor, input_descriptor, &input_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardFilterWithAlgorithm( input_descriptor, input_buf, output_descriptor, output_buf, convolution_descriptor, filter_descriptor, &filter_buf, scratch_allocator, algorithm, profile_result); break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. + if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } } if (!stream->ok()) { @@ -214,32 +302,105 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, return Status::OK(); } -} // anonymous namespace +// Returns the cudnn convolution parameters generated from conv, which must be a +// custom-call to a cudnn convolution. +StatusOr GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + CudnnConvParams params; + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config()); + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); + const auto& lhs_shape = conv->operand(0)->shape(); + const auto& rhs_shape = conv->operand(1)->shape(); + const auto& conv_result_shape = conv->shape().tuple_shapes(0); + + params.kind = kind; + params.window = &conv->window(); + params.dnums = &conv->convolution_dimension_numbers(); + params.feature_group_count = conv->feature_group_count(); + params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); -string CudnnConvKindToString(CudnnConvKind kind) { switch (kind) { case CudnnConvKind::kForward: - return "forward"; - case CudnnConvKind::kBackwardFilter: - return "backward_filter"; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; case CudnnConvKind::kBackwardInput: - return "backward_input"; + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + case CudnnConvKind::kForwardActivation: { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } + } } + return params; } -Status RunCudnnConvolution(CudnnConvParams params, +} // anonymous namespace + +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(params, &scratch_allocator, stream, - profile_result); + return RunCudnnConvolution(conv, operand_buffers, result_buffer, + &scratch_allocator, stream, profile_result); } -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = params.output_shape->element_type(); + TF_ASSIGN_OR_RETURN(CudnnConvParams params, + GetCudnnConvParams(conv, operand_buffers, result_buffer)); + + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvolutionImpl(params, scratch_allocator, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 381aa37a1b1405e00d62adf9855e9229482f5b86..61aec1ceccec0f253f9ddaa688d64cacea800cf3 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -27,52 +30,8 @@ namespace gpu { // This file contains low-level routines for running cudnn convolutions. -// Different types of convolutions supported by cudnn. -// -// A way to think about these is that a convolution is defined by three arrays -// -- the "input", the "filter", and the "output" -- and given any two of these, -// we can compute the third. For example, a backward-input convolution takes as -// input a filter and an "output" and produces an "input" such that if one were -// to do a forward convolution of "input" using filter, the result would be -// something with the same shape as "output". -// -// This way of thinking is not correct if you look at the values produced. For -// example, a backward-input convolution is not actually the mathematical -// inverse of a forward convolution. But it's right as far as the shapes and -// "connectivity" (i.e. which elements of the input affect which elements of -// the output) are concerned. -enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter -}; - -struct CudnnConvParams { - CudnnConvKind kind; - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; - se::DeviceMemoryBase input_buf; - se::DeviceMemoryBase filter_buf; - se::DeviceMemoryBase output_buf; - const Window* window; - const ConvolutionDimensionNumbers* dnums; - int64 feature_group_count; - se::dnn::AlgorithmConfig algorithm; -}; - -// Converts a CudnnConvKind value to a string. -string CudnnConvKindToString(CudnnConvKind kind); - // Calls into cudnn to run the specified convolution. // -// Note that depending on the value of CudnnConvKind, the result of this call -// may be written into input_buf, filter_buf, or output_buf! -// -// At the moment convolution with half data type is implemented with cudnn -// PSEUDO_HALF configuration, that is, the input values are half and the -// internal computation type is float. -// // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In // theory the second one shouldn't be necessary -- users of this function could @@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..d508cbc2e1cf4cec071857ec5e048e8bd0be1015 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc @@ -0,0 +1,279 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. + + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. + // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast(conv_instr); + auto bias_broadcast = + CastOrNull(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. + if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull(alpha_conv_instr), + CastOrNull(alpha_side_input_instr)}; +} + +StatusOr> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast(instr)->literal().Convert(F64)); + return alpha.GetFirstElement(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + config.set_activation_mode( + static_cast(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << new_conv->ToString(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr CudnnFusedConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector>> + replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h similarity index 56% rename from tensorflow/compiler/xla/service/gpu/gpu_options.h rename to tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h index 498d4a94955cb2c50e0b165f28ded44ac1c0bfff..bd12aadded9dd9e19bc695ddc11e5529931a306a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h @@ -13,21 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ -#include "tensorflow/compiler/xla/service/hlo_module_config.h" - -// Helper functions for querying options that are specific to the GPU backend. +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. -bool ConvUseLayoutHeuristic(const HloModuleConfig& config); +class CudnnFusedConvolutionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04ddc31edf723c056805ae5aad994e55..6dcdaf1cfe06e446deed847aaf29088a7ed10e13 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. " - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 7e3f5775b8d97f43a0bba201d24f34c2d337fabb..f19996edfe3dd923aa686a19621ce28a4aed5a45 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -32,7 +32,7 @@ namespace gpu { // 2) The result of merging the fusion instruction into its users would not // increase bytes transferred. // -class FusionMerger : public HloPassInterface { +class FusionMerger : public HloModulePass { public: absl::string_view name() const override { return "fusion merger"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 75f414e47fe3edcc1b10b392ed5cc5038be6c190..e2ab00ce41c9e23e91449f249620d61d0f7736ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -27,22 +28,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -StatusOr GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - HloInstruction*& copy = hlo_to_copy_map_[hlo]; - if (copy == nullptr) { - TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); - } - return copy; -} - StatusOr GpuCopyInsertion::Run(HloModule* module) { CopyInsertion generic_copy_insertion; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 8ffae18fe820aa01701731ee56a83aeacf0eab0d..4c7e38ffeb60f87a4f27e212572ae31cca8e0947 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -25,20 +25,11 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. The key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 31a9f9b1beb81da81a06f6dc8e7c13c105514092..57426327822d95a42f407ed7488f35acfd3623d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" @@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { } module_spec.AddCudaPtxInMemory(ptx().c_str()); - tensorflow::gtl::FlatMap globals; + absl::flat_hash_map globals; se::ModuleHandle module_handle; executor->LoadModule(module_spec, &module_handle); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 38b0f8f15bd28cf2659e4a53b6634e981545716b..0e276282e40fba0ae4881a51dad0c7c9e8d1c081 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -101,7 +101,7 @@ class GpuExecutable : public Executable { const PointsToSet& GetRootPointsToSet() const; using BufferAllocToDeviceMemoryMap = - tensorflow::gtl::FlatMap; + absl::flat_hash_map; // Loads the PTX or CUBIN for this executable into `executor` and resolves the // globals corresponding to constant buffers. Returns a map mapping buffer diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index bbb3340760c8330bd6570f33382f004315c6d0bd..9c64b4d10c9d1b172f7bd89b5fdacda893488bf8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // his pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the GPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class GpuHloSupportChecker : public HloPassInterface { +class GpuHloSupportChecker : public HloModulePass { public: GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index d033faee8d25ed81a1483f8314652ef999ab36c5..8c9a8adc614e748ffc431dd2dc7fa1de3d38cf40 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,45 +91,46 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // operands and the output shape. Depending on the underlying algorithm, one of // { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints) { - CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); - Shape input_shape; - Shape filter_shape; - Shape output_shape; - const auto& target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->shape().tuple_shapes(0); - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_shape = instr->shape().tuple_shapes(0); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->operand(0)->shape(); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->shape().tuple_shapes(0); - output_shape = instr->operand(1)->shape(); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + HloCustomCallInstruction* instr, LayoutConstraints* constraints) { + Shape lhs_shape = instr->operand(0)->shape(); + Shape rhs_shape = instr->operand(1)->shape(); + Shape result_shape = instr->shape().tuple_shapes(0); + + Shape* input_shape; + Shape* filter_shape; + Shape* output_shape; + + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + input_shape = &lhs_shape; + filter_shape = &rhs_shape; + output_shape = &result_shape; + break; + case CudnnConvKind::kBackwardInput: + input_shape = &result_shape; + filter_shape = &rhs_shape; + output_shape = &lhs_shape; + break; + case CudnnConvKind::kBackwardFilter: + input_shape = &lhs_shape; + filter_shape = &result_shape; + output_shape = &rhs_shape; + break; } { DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( - std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), - *output_shape.mutable_layout()), + std::tie(*input_shape->mutable_layout(), + *filter_shape->mutable_layout(), + *output_shape->mutable_layout()), StreamExecutorConvLayoutsToXlaLayouts( instr->convolution_dimension_numbers(), input, filter, output)); } @@ -141,24 +143,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( instr, /*index=*/{0})); // Set layouts of the instructions' shapes. - if (target == kCudnnConvForwardCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardInputCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(result_shape.layout(), *call_result_buf)); + // instr->operand(2), if exists, is the bias buffer. There is no need to + // assign layout to it, as it has only one dimension. + + // instr->opernad(3), if exists, is the side input buffer. + if (instr->operand_count() == 4) { + if (kind != CudnnConvKind::kForwardActivation) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward: %s", + instr->ToString()); + } + // The side input layout must match the output layout. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3)); } return Status::OK(); } @@ -173,8 +174,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( ++iterator) { HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { - TF_RETURN_IF_ERROR( - AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); + TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( + Cast(instruction), constraints)); } // For batched dot we require the default layout. @@ -212,16 +213,6 @@ Status GpuLayoutAssignment::AddBackendConstraints( return Status::OK(); } -bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) { - // - Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - // - Inputs to cudnn convolution require custom layouts handled in - // AddBackendConstraints. - return !IsCustomCallToDnnBatchNorm(*instruction) && - !IsCustomCallToDnnConvolution(*instruction); -} - Status GpuLayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f..6a48e55fd2e784f80a50f4565107db177fb43bfc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -29,8 +30,11 @@ namespace gpu { class GpuLayoutAssignment : public LayoutAssignment { public: explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, se::StreamExecutor* stream_executor) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} @@ -42,12 +46,10 @@ class GpuLayoutAssignment : public LayoutAssignment { Status PropagateBufferConstraint( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) override; - bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) override; private: Status AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints); + HloCustomCallInstruction* instr, LayoutConstraints* constraints); se::StreamExecutor* stream_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index fbc8ddf599570b90e93eb463a1fd6c275b73711c..04681cfcec792d86eed95585262691932b07b269 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the @@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); Shape expected_shape = diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 4d5d8e99f88149aabfd0a4aeafc7e6724d29418d..b61f0387392d2301109a484ca5c1f65f18882265 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } // Compute the precise number of operands to the new fusion. - tensorflow::gtl::FlatSet operands( - a->operands().begin(), a->operands().end()); + absl::flat_hash_set operands(a->operands().begin(), + a->operands().end()); operands.insert(b->operands().begin(), b->operands().end()); // If there's an edge between `a` and `b`, don't count it: We're fusing that // producer -> consumer relationship. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 22f43bc08bd08abd735f88f32f28c528499cf3d2..ec3d8f9405840bb7be97ba5cd5725a4ac68a15a8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget = "__cudnn$convBackwardInput"; const char* const kCudnnConvBackwardFilterCallTarget = "__cudnn$convBackwardFilter"; +const char* const kCudnnConvBiasActivationForwardCallTarget = + "__cudnn$convBiasActivationForward"; bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { @@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || target == kCudnnConvBackwardInputCallTarget || - target == kCudnnConvBackwardFilterCallTarget; + target == kCudnnConvBackwardFilterCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget; } bool ImplementedAsLibraryCall(const HloInstruction& hlo) { @@ -145,59 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv(const char* call_target, - const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - HloComputation* computation = lhs->parent(); - - // This call returns a tuple of (conv_result, scratch_memory), where - // conv_result is the actual result of the convolution, and scratch_memory is - // temporary memory used by cudnn. - // - // At the moment, we don't know how much scratch memory this conv is going to - // use, so we put u8[0] in this place. Later on another pass will choose - // which conv algorithm to use, and at that point we'll modify the shape of - // this second tuple element. - Shape call_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - - HloInstruction* custom_call = computation->AddInstruction( - HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); - custom_call->set_window(window); - custom_call->set_convolution_dimension_numbers(dnums); - custom_call->set_feature_group_count(feature_group_count); - return custom_call; -} - -HloInstruction* CreateCudnnConvForward(const Shape& shape, - HloInstruction* input, - HloInstruction* kernel, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums, feature_group_count); -} - -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums, feature_group_count); -} - -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums, feature_group_count); -} - bool IsReductionToVector(const HloInstruction& reduce) { if (HloOpcode::kReduce != reduce.opcode()) { return false; @@ -288,41 +238,35 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } -Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, - CudnnConvParams* params) { - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); - const auto& target = custom_call->custom_call_target(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); - - params->window = &custom_call->window(); - params->dnums = &custom_call->convolution_dimension_numbers(); - params->feature_group_count = custom_call->feature_group_count(); - params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( - backend_config.algorithm(), backend_config.tensor_ops_enabled())); - +StatusOr GetCudnnConvKind( + const HloCustomCallInstruction* instr) { + absl::string_view target = instr->custom_call_target(); if (target == kCudnnConvForwardCallTarget) { - params->kind = CudnnConvKind::kForward; - params->input_shape = &lhs_shape; - params->filter_shape = &rhs_shape; - params->output_shape = &conv_result_shape; - } else if (target == kCudnnConvBackwardInputCallTarget) { - params->kind = CudnnConvKind::kBackwardInput; - params->input_shape = &conv_result_shape; - params->filter_shape = &rhs_shape; - params->output_shape = &lhs_shape; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - params->kind = CudnnConvKind::kBackwardFilter; - params->input_shape = &lhs_shape; - params->filter_shape = &conv_result_shape; - params->output_shape = &rhs_shape; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); + return CudnnConvKind::kForward; + } + if (target == kCudnnConvBackwardInputCallTarget) { + return CudnnConvKind::kBackwardInput; + } + if (target == kCudnnConvBackwardFilterCallTarget) { + return CudnnConvKind::kBackwardFilter; + } + if (target == kCudnnConvBiasActivationForwardCallTarget) { + return CudnnConvKind::kForwardActivation; + } + return InternalError("Unexpected call target: %s", target); +} + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + case CudnnConvKind::kForwardActivation: + return "forward with activation"; } - return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 09c455cc1e137b4a9836a58d5b70e62a4bfa120a..a64a616ab1329422d0197f4a7f99ec557a95f8ed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -30,6 +29,33 @@ limitations under the License. namespace xla { namespace gpu { +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output +}; + +StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. @@ -95,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. @@ -104,28 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget; // kConvolution opcode. bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); -// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. -// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If -// you want just the conv result, you'll need to get-tuple-element the value -// returned by this function. -// -// The created cudnn call will use the default cudnn algorithm and no scratch -// space. -HloInstruction* CreateCudnnConvForward(const Shape& shape, - HloInstruction* input, - HloInstruction* kernel, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); - // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); @@ -150,11 +155,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); -// Populates params using conv, which must be a custom-call to a cudnn -// convolution. Does not modify any buffers in the params. -Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, - CudnnConvParams* params); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index b7c37bcf3ca910f10d18339dfe7f1d29f2a55c9e..47102347cbf3fbcc6f3979814eec3515f3af0e46 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -179,6 +179,21 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; llvm::Value* source = Load(source_address, "source"); + + // kCopy of RHS -> atomic store. + if (root_opcode == HloOpcode::kCopy && + (element_type == F32 || is_atomic_integral) && + computation.root_instruction()->operand(0)->opcode() == + HloOpcode::kParameter && + computation.root_instruction()->operand(0)->parameter_number() == 1) { + llvm::StoreInst* store = Store(source, output_address); + store->setAtomic(llvm::AtomicOrdering::Unordered); + // Derive a minimum alignment from the type. The optimizer can increase it + // later. + store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type)); + return true; + } + if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b669881026276eefe2ca6cbea74d79604dd13066..09486d291abe6898b223d10fb734b5d4f383db58 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const auto& target = custom_call->custom_call_target(); - BufferAllocation::Slice input_slice, filter_slice, output_slice; - - if (target == kCudnnConvForwardCallTarget) { - input_slice = lhs_slice; - filter_slice = rhs_slice; - output_slice = conv_result_slice; - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_slice = conv_result_slice; - filter_slice = rhs_slice; - output_slice = lhs_slice; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_slice = lhs_slice; - filter_slice = conv_result_slice; - output_slice = rhs_slice; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - thunk_sequence_->emplace_back(absl::make_unique( - Cast(custom_call), input_slice, filter_slice, - output_slice, scratch_slice, tuple_result_slice)); + Cast(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } @@ -1975,6 +1958,151 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { return Status::OK(); } +Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape())); + + std::vector> thunks; + + // Copy the operand into the output if it's not the same buffer already. + auto operand_buffer = GetAllocationSlice(*operand); + auto destination_buffer = GetAllocationSlice(*scatter); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter)); + } + + auto loop_body_emitter = [&](const IrArray::Index& index) -> Status { + std::vector raw_window_multidim; + std::vector input_scatter_multidim; + std::vector raw_window_bounds; + + // Partition the index into window indices and scatter indices. + for (int64 i = 0, e = index.size(); i != e; ++i) { + // For window indices also remember the window size, this comes in handy + // later. + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + raw_window_multidim.push_back(index[i]); + raw_window_bounds.push_back(updates->shape().dimensions(i)); + } else { + input_scatter_multidim.push_back(index[i]); + } + } + DCHECK_EQ(raw_window_multidim.size(), + dim_numbers.update_window_dims_size()); + + // Apply inserted_window_dims to the window dimensions. + int64 raw_window_multidim_idx = 0; + std::vector input_window_multidim; + std::vector input_window_bounds; + for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_window_bounds.push_back(1); // Trivial dimension. + input_window_multidim.push_back(index.GetConstantWithIndexType(0)); + } else { + input_window_bounds.push_back( + raw_window_bounds[raw_window_multidim_idx]); + input_window_multidim.push_back( + raw_window_multidim[raw_window_multidim_idx]); + ++raw_window_multidim_idx; + } + } + DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + + // Insert a 1 dimension at the end if index_vector_dim requests one. + Shape scatter_indices_shape = scatter_indices->shape(); + if (dim_numbers.index_vector_dim() == + ShapeUtil::Rank(scatter_indices_shape)) { + scatter_indices_shape.add_dimensions(1); + scatter_indices_shape.mutable_layout()->add_minor_to_major( + dim_numbers.index_vector_dim()); + } + llvm_ir::IrArray scatter_indices_reshaped = + GetIrArray(*scatter_indices, *scatter) + .CastToShape(scatter_indices_shape, &b_); + + // Now load the indices corresponding to the current window from + // scatter_indices. + llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, + index.GetType()); + raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + llvm::Value* is_in_bounds = b_.getTrue(); + for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); + i != e; ++i) { + // Our index is stored along index_vector_dim, insert that into the lookup + // index into scatter_indices. + raw_scatter_index_index[dim_numbers.index_vector_dim()] = + raw_scatter_index_index.GetConstantWithIndexType(i); + + int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); + llvm::Value* loaded_scatter_index = + scatter_indices_reshaped.EmitReadArrayElement(raw_scatter_index_index, + &b_, "scatter_index"); + // And add the index to our window index. This yields the output index. + llvm::Value* dim_offset = + Add(input_window_multidim[operand_dim], + IntCast(loaded_scatter_index, index.GetType(), + /*isSigned=*/true)); + input_window_multidim[operand_dim] = dim_offset; + + // Also do the bounds check now. + int64 max_index = operand->shape().dimensions(operand_dim) - + input_window_bounds[operand_dim] + 1; + // is_in_bounds = dim_offset >= 0 && dim_offset < dim_size-window_size+1 + // --> dim_offset u< dim_size-window_size+1 + is_in_bounds = + And(is_in_bounds, + ICmpULT(dim_offset, index.GetConstantWithIndexType(max_index))); + } + + llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( + is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false); + llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); + // All done, now just read from the calculated input from the window, and do + // an atomic store to the calculated location in the output. + llvm_ir::IrArray::Index input_window_index(input_window_multidim, + index.GetType()); + llvm::Value* input_address = + GetIrArray(*updates, *scatter).EmitArrayElementAddress(index, &b_); + llvm::Value* output_address = + GetIrArray(*scatter, *scatter) + .EmitArrayElementAddress(input_window_index, &b_); + return EmitAtomicOperationForNestedComputation( + *scatter->to_apply(), output_address, input_address); + }; + + // Launch a kernel that reads every element in the updates tensor. We could + // also do one kernel per window instead if bounds checks turn out to be a + // bottleneck. + thunks.push_back( + BuildKernelThunk(scatter, + /*implements_whole_instruction=*/thunks.empty())); + + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + updates->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, + static_cast(thunks.back().get()), + ir_emitter_context_->llvm_module()); + + if (thunks.size() == 1) { + thunk_sequence_->push_back(std::move(thunks[0])); + } else { + thunk_sequence_->emplace_back( + absl::make_unique(std::move(thunks), scatter)); + } + return ParallelLoopEmitter(loop_body_emitter, updates->shape(), + launch_dimensions, &b_) + .EmitLoop(IrName(scatter), + GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(), + &b_)); +} + Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { thunk_sequence_->push_back( BuildKernelThunk(select, /*implements_whole_instruction=*/true)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index bd5db7205155dc6b15ddea069e172bbd8f419996..2e36e7235be89ab6b2909b2f133af76cc297edd2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c21f76f6eb1874bfa5a1d296c78ea0e3b9261eca..835924024b7b7de79624a369a69b07d72ac751ab 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; for (auto instr : instr1->operands()) { if (!IsProfitableOperand(instr)) { continue; @@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { bool changed = false; RecomputeReachability(); - tensorflow::gtl::FlatSet to_fuse; + absl::flat_hash_set to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, // then filter out instructions that will be no longer fusible because of diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index dfdcf1875dd3f5749bd1fd95ad0eeb8c11955887..5409f655896b14bfc810a6e46425812396fbb10a 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -74,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" -#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -175,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - pipeline.AddPass(); - pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -208,6 +206,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -230,14 +229,17 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // a layout-sensitive verifier! HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, stream_exec); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -283,8 +285,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + fusion.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -296,7 +300,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline reduce_pipeline("reduce-precision"); reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -322,8 +327,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -398,11 +405,11 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { "prefers >= 9.2.88). Compilation of XLA kernels below will likely " "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " "binary is sufficient."; - } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " "miscompile XLA code, leading to incorrect results or " "invalid-address errors.\n\nYou do not need to update to CUDA " "9.2.88; cherry-picking the ptxas binary is sufficient."; diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 8e97774750344bfc141daa7d752300762c708613..c4a0b727cd3d9ae0af61c1752c1608cd4fb65d2d 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/node_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler { tensorflow::condition_variable compilation_done_cv_; }; - // Don't even think about switching this to FlatMap; iterator stability is - // critical here. - std::unordered_map + // Don't even think about switching this to flat_hash_map; iterator stability + // is critical here. + absl::node_hash_map compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index b0061fa6558ac92bffd3dff13e736421a62dc484..8f1f5a7bf5b36035bd7149f598df4515da4ff08c 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -36,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8; // there's additional room for speedups. Achieving those speedups without also // slowing other things down will likely require a more sophisticated heuristic, // possibly some form of auto-tuning. -static constexpr double kMaxBytesTouchedIncrease = 1.2; +// +// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4" +// special case inside PadShape won't fire. +static constexpr double kMaxBytesTouchedIncrease = 1.35; // Pads the given dimensions in the given shape up to a multiple of // kDesiredNumFeaturesFactor. static Shape PadShape(Shape s, absl::Span dims) { for (int64 dim : dims) { int64 dim_to_pad_size = s.dimensions(dim); - int64 new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + + // Round dim_to_pad_size up to the next multiple of + // kDesiredNumFeaturesFactor. + // + // Special case: dims of size 3 are rounded up to 4, not + // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia), + // this helps, but as of writing, it's not supported by anything in the + // cudnn docs. + int64 new_dim_to_pad_size; + if (dim_to_pad_size == 3) { + new_dim_to_pad_size = 4; + } else { + new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + } + s.set_dimensions(dim, new_dim_to_pad_size); } return s; @@ -87,38 +105,45 @@ static HloInstruction* PadInstruction(HloInstruction* instr, // Pads the input/output feature dimensions of the given cudnn convolution // custom-call to be multiples of kDesiredNumFeaturesFactor. -static StatusOr PadFeaturesDims(HloInstruction* conv) { +static StatusOr PadFeaturesDims(HloCustomCallInstruction* conv) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " "before CudnnConvolutionAlgorithmPicker."; - const auto& target = conv->custom_call_target(); + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); auto* rhs = conv->mutable_operand(1); const Shape& result_shape = conv->shape().tuple_shapes(0); Shape new_lhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardFilterCallTarget) { - // LHS is "input". - return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kBackwardFilter: + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + case CudnnConvKind::kBackwardInput: + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); - // LHS is "output". - return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); }(); Shape new_rhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardInputCallTarget) { - // RHS is "filter". - return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kBackwardInput: + // RHS is "filter". + return PadShape(rhs->shape(), + {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + case CudnnConvKind::kBackwardFilter: + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // RHS is "output". - return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); }(); if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && @@ -128,18 +153,21 @@ static StatusOr PadFeaturesDims(HloInstruction* conv) { } Shape new_result_shape = [&] { - if (target == kCudnnConvForwardCallTarget) { - // Result is "output". - return PadShape(result_shape, {dnums.output_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + case CudnnConvKind::kBackwardInput: + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + case CudnnConvKind::kBackwardFilter: + // Result is "filter". + return PadShape(result_shape, + {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - if (target == kCudnnConvBackwardInputCallTarget) { - // Result is "input". - return PadShape(result_shape, {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // Result is "filter". - return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); }(); // Check that padding wouldn't increase the total bytes read/written by this @@ -205,12 +233,20 @@ static StatusOr PadFeaturesDims(HloInstruction* conv) { return true; } -static std::vector GetRelevantConvs(HloComputation* comp) { - std::vector convs; +static std::vector GetRelevantConvs( + HloComputation* comp) { + std::vector convs; for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16) { - convs.push_back(instr); + if (!IsCustomCallToDnnConvolution(*instr)) { + continue; + } + auto* custom_call = Cast(instr); + if (custom_call->operand(0)->shape().element_type() == F16 && + // TODO(timshen): Disable for fused conv for now. Implement it if it's + // needed. + custom_call->custom_call_target() != + kCudnnConvBiasActivationForwardCallTarget) { + convs.push_back(custom_call); } } return convs; @@ -219,7 +255,7 @@ static std::vector GetRelevantConvs(HloComputation* comp) { StatusOr PadForTensorCores::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* conv : GetRelevantConvs(comp)) { + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); changed |= result; } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 11dc56a64fda74cab12024e5f2c6fa2f63c9167d..e592a3774ec28605fda912298c74ca7976ff99ac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -30,7 +30,7 @@ namespace gpu { // targeting before running this pass. // // TODO(jlebar): Also pad dots. -class PadForTensorCores : public HloPassInterface { +class PadForTensorCores : public HloModulePass { public: absl::string_view name() const override { return "pad for tensor cores"; } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2a6415d0b6c973cb72c30b7a803b5f603c1d5e4d..ae7abca7c6ebb00b2d17c490ba3a41df2cc7415d 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -30,7 +31,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -161,12 +163,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. - Shape old_conv_shape = conv->shape().tuple_shapes(0); - VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward( - old_conv_shape, new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers(), conv->feature_group_count()); + std::vector operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), operands)); + new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -242,10 +246,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. - Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( - backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -308,9 +312,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( - new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. @@ -372,24 +379,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - std::vector convs; + std::vector convs; for (auto* instr : computation->instructions()) { if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); + convs.push_back(Cast(instr)); } } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index a622e894ed9c0d1534262e6b72a5f4ea7b7821ad..25cdf64c4cf01300869044d3e4d7c34c85626a5a 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -24,7 +24,7 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to cuDNN convolution. -class PadInsertion : public HloPassInterface { +class PadInsertion : public HloModulePass { public: absl::string_view name() const override { return "pad insertion"; } diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index cf9f102d31305da15dabaf6247f23c5ca9a9e054..375f68a15957936151aee068582a714b62694af2 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -62,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - auto threads_per_core = device_desc.threads_per_core_limit(); - auto blocks_per_core = device_desc.blocks_per_core_limit(); - int64 threads_per_block; - if (threads_per_core != 0 && blocks_per_core != 0) { - threads_per_block = device_desc.threads_per_core_limit() / - device_desc.blocks_per_core_limit(); - } else { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { static std::atomic log_count{0}; if (log_count.fetch_add(1) < 8) { LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h index c2df83aaa4347a9439798acc6cfc2ba0db995232..52d38b6f20e8d61e2d4966ad15a5583a9cd2e945 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace gpu { @@ -34,7 +34,7 @@ class StreamAssignment { private: int stream_count_ = 1; // At least the main stream. - tensorflow::gtl::FlatMap hlo_to_stream_number_; + absl::flat_hash_map hlo_to_stream_number_; }; // Assigns GPU streams to instructions in `module`. diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index db4a33dc564b62b5fe54b725ea453a6fcbfb3287..1f0436278c5c55a5266ecdd87ad904472e12a5d7 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -25,15 +25,17 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "gpu_codegen_test", testonly = True, srcs = ["gpu_codegen_test.cc"], hdrs = ["gpu_codegen_test.h"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -48,9 +50,7 @@ cc_library( tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -67,9 +67,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/core:test_main", @@ -79,9 +77,7 @@ tf_cc_test( tf_cc_test( name = "gpu_index_test", srcs = ["gpu_index_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -102,9 +98,7 @@ tf_cc_test( tf_cc_test( name = "gpu_infeed_test", srcs = ["infeed_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -125,9 +119,7 @@ tf_cc_test( tf_cc_test( name = "gpu_kernel_tiling_test", srcs = ["gpu_kernel_tiling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo", @@ -142,7 +134,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ldg_test", srcs = ["gpu_ldg_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -159,9 +151,7 @@ tf_cc_test( tf_cc_test( name = "gpu_noalias_test", srcs = ["gpu_noalias_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -178,9 +168,7 @@ tf_cc_test( tf_cc_test( name = "gpu_fusion_test", srcs = ["gpu_fusion_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -194,9 +182,7 @@ tf_cc_test( tf_cc_test( name = "gpu_unrolling_test", srcs = ["gpu_unrolling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -211,9 +197,7 @@ tf_cc_test( name = "gpu_alignment_test", testonly = True, srcs = ["gpu_alignment_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -225,3 +209,29 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cudnn_fused_convolution_rewriter_test", + srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "gpu_atomic_test", + srcs = ["gpu_atomic_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5632cac1862e21825888d94ab1eee5e1c9fd6800 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc @@ -0,0 +1,283 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CudnnFusedConvolutionRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_EQ(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convForward")) + << optimized_hlo_string; + EXPECT_NE(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convBiasActivationForward")) + << optimized_hlo_string; + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); + EXPECT_NE(absl::string_view::npos, + optimized_hlo.find("__cudnn$convForward")) + << optimized_hlo; + EXPECT_EQ(absl::string_view::npos, + optimized_hlo.find("__cudnn$convBiasActivationForward")) + << optimized_hlo; + } + } +}; + +TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, + TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b18c4c63714b4b3c06d7fa85f4a7a75b8e9ae12 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAtomicTest : public GpuCodegenTest {}; + +TEST_F(GpuAtomicTest, TestStore) { + const char* hlo_string = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: store atomic{{.*}}unordered, align 4 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index e0f3a7e0e2869fa854c0229cd06bbdd641d99363..9220865867b770eebfb1ada8f31a5d24693a4b8d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,14 +18,16 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +using absl::flat_hash_map; +using absl::flat_hash_set; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( @@ -56,7 +58,7 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, @@ -88,7 +90,7 @@ StatusOr HeapSimulator::Run( const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/nullptr, memory_by_computation); @@ -115,8 +117,10 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - FlatMap> live_buffers; - FlatMap> used_buffers; + flat_hash_map> + live_buffers; + flat_hash_map> + used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, const BufferValue* buffer) { @@ -213,7 +217,7 @@ Status HeapSimulator::RunComputation( VLOG(4) << " Removing user " << instruction->name() << " from buffer " << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); - FlatSet* live_set = &it->second; + flat_hash_set* live_set = &it->second; live_set->erase(instruction); if (live_set->empty()) { live_buffers.erase(it); @@ -235,7 +239,8 @@ Status HeapSimulator::RunComputation( // that we should assign. // Make sure each buffer get reused at most once. - FlatSet reused_buffers; + flat_hash_set reused_buffers; + int64 alloc_size_by_instruction = 0; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -268,14 +273,15 @@ Status HeapSimulator::RunComputation( if (!shared) { VLOG(3) << " Allocating: " << buffer->ToString(); + alloc_size_by_instruction += size_fn_(*buffer); Alloc(buffer, instruction); } } // Account for the memory used by subcomputations when estimating the // current heap size. if (memory_by_computation_ != nullptr) { - algorithm_->AccountForSubcomputationMemory(instruction, - *memory_by_computation_); + algorithm_->AccountForSubcomputationMemory( + instruction, alloc_size_by_instruction, *memory_by_computation_); } // If all computations in the module have been scheduled, we can save memory @@ -323,7 +329,7 @@ Status HeapSimulator::RunComputation( to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const BufferValue* buffer = buffer_pending.first; - const FlatSet& pending = buffer_pending.second; + const flat_hash_set& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; to_free.push_back(buffer); @@ -345,7 +351,7 @@ HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), @@ -381,10 +387,8 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - const HloInstruction* instruction_to_calc_aliasing = - memory_by_computation_ == nullptr ? nullptr : instruction; - algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); - no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); + algorithm_->Alloc(buffer, size); + no_fragmentation_stats_->Alloc(buffer, size); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -522,21 +526,9 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } -void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. Don't double count. - if (instruction == nullptr || - (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kCall && - instruction->opcode() != HloOpcode::kConditional)) { - Alloc(buffer, size); - } -} - void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. @@ -550,6 +542,14 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( } } } + if (max_subcomputation_bytes > 0 && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + max_subcomputation_bytes -= alloc_size_by_instruction; + } max_heap_size_ = std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } @@ -736,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() { return result_; } +void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + result_.chunk_map.emplace(buffer, Chunk{0, 0}); + return; + } + auto emplace_result = buffer_intervals_.emplace( + buffer, BufferInterval{buffer, size, current_time_, -1}); + DCHECK(emplace_result.second); + ++current_time_; +} + +void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + return; + } + BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer); + DCHECK_EQ(buffer_interval.buffer, buffer); + DCHECK_EQ(buffer_interval.size, size); + DCHECK_EQ(buffer_interval.end, -1); + buffer_interval.end = current_time_; + ++current_time_; +} + +namespace { + +// Node in BufferIntervalTree that stores the alloc and free times of a buffer, +// and the chunk assigned to it. +struct BufferIntervalTreeNode { + // Alloc time. + int64 start; + // Free time. + int64 end; + // Maximum free time of all nodes in the subtree where this node is the root. + int64 subtree_end; + // Allocated chunk for the buffer. + HeapSimulator::Chunk chunk; + // Left child. + BufferIntervalTreeNode* left; + // Right child. + BufferIntervalTreeNode* right; +}; + +// An interval tree that can query buffers overlapping in time. +class BufferIntervalTree { + public: + explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {} + + using Chunk = HeapSimulator::Chunk; + + // Adds a buffer to the interval tree, with the time interval and allocated + // chunk specified. + void Add(int64 start, int64 end, const Chunk& chunk) { + int index = node_count_; + DCHECK_LT(index, node_storage_.size()); + ++node_count_; + + node_storage_[index] = + BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr}; + + if (index == 0) { + // This is root. + return; + } + + BufferIntervalTreeNode* parent = &node_storage_[0]; + while (true) { + parent->subtree_end = std::max(parent->subtree_end, end); + if (parent->start > start) { + if (parent->left == nullptr) { + parent->left = &node_storage_[index]; + return; + } + parent = parent->left; + } else { + if (parent->right == nullptr) { + parent->right = &node_storage_[index]; + return; + } + parent = parent->right; + } + } + } + + // Returns vector of allocated chunks that overlap with the given time + // interval. + std::vector ChunksOverlappingInTime(int64 start, int64 end) { + std::vector result; + if (node_count_ == 0) { + return result; + } + std::vector visiting_stack; + visiting_stack.push_back(&node_storage_[0]); + while (!visiting_stack.empty()) { + BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + result.push_back(top->chunk); + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; + } + + private: + int64 node_count_ = 0; + std::vector node_storage_; +}; + +} // namespace + +HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { + std::vector sorted_buffer_intervals; + for (auto& entry : buffer_intervals_) { + sorted_buffer_intervals.push_back(entry.second); + } + std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); + + BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); + for (auto& buffer_interval : sorted_buffer_intervals) { + auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( + buffer_interval.start, buffer_interval.end); + std::sort( + chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); + + // Find the minimum free chunk that can hold this buffer. + Chunk min_fit_chunk{-1, INT64_MAX}; + auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) { + if (free_size < buffer_interval.size) { + return; + } + + if (free_size < min_fit_chunk.size) { + min_fit_chunk = {free_offset, free_size}; + } + }; + + int64 offset = 0; + for (auto& chunk : chunks_overlapping_in_time) { + if (offset < chunk.offset) { + use_free_chunk_if_smaller(offset, chunk.offset - offset); + } + offset = + std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_)); + } + use_free_chunk_if_smaller(offset, result_.heap_size - offset); + + if (min_fit_chunk.offset == -1) { + // Increase the heap size to fit in the last free chunk. + result_.heap_size = offset + buffer_interval.size; + min_fit_chunk = {offset, buffer_interval.size}; + } + + min_fit_chunk.size = buffer_interval.size; + const auto emplace_result = + result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk); + DCHECK(emplace_result.second); + + interval_tree.Add(buffer_interval.start, buffer_interval.end, + min_fit_chunk); + } + return result_; +} + +HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { + DCHECK(!algorithms_.empty()); + std::vector results(algorithms_.size()); + int64 min_size = INT64_MAX; + int min_size_index = -1; + for (int i = 0; i < algorithms_.size(); ++i) { + results[i] = algorithms_[i]->Finish(); + if (results[i].heap_size < min_size) { + min_size = results[i].heap_size; + min_size_index = i; + } + } + + DCHECK_GE(min_size_index, 0); + return results[min_size_index]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index ffbf947d5ad0cf598f9de9f98f5bbe344f095993..dbbf43082f2c1d21f5ef42f53804bf0969903a58 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -30,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,7 +58,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + absl::flat_hash_map chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -100,7 +100,7 @@ class HeapSimulator { const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given @@ -130,7 +130,7 @@ class HeapSimulator { const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); private: @@ -140,7 +140,7 @@ class HeapSimulator { HeapSimulator(std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule = nullptr, - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation = nullptr); ~HeapSimulator(); @@ -172,7 +172,7 @@ class HeapSimulator { // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. const HloSchedule* schedule_; - const tensorflow::gtl::FlatMap* + const absl::flat_hash_map* memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of @@ -193,12 +193,12 @@ class HeapSimulator { const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + absl::flat_hash_map> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + absl::flat_hash_set allocated_buffers_; + absl::flat_hash_set freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -218,12 +218,6 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; - // NoFragmentationStatsHeap overrides this method. - virtual void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - Alloc(buffer, size); - } - // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing // between computations entirely correctly. We are careful to not double count @@ -235,7 +229,9 @@ class HeapAlgorithm { // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + // The total number of bytes allocated by instruction. + int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) {} // Free de-allocates a previously allocated buffer. @@ -257,12 +253,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; - void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) override; - void AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map& memory_by_computation) override; void Free(const BufferValue* buffer, int64 size) override; @@ -351,6 +344,67 @@ class LazyBestFitHeap : public HeapAlgorithm { std::set free_; }; +// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, +// then allocates them in decreasing sizes regardless of the alloc/free time. It +// internally tracks the allocated buffers and their live intervals; when +// allocating a buffer, it finds the best-fit free chunk during its live +// interval. +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { + public: + GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {} + ~GlobalDecreasingSizeBestFitHeap() override {} + + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; + + private: + int64 alignment_; + Result result_; + + // The current time represented as an integer. It increments by 1 at each + // Alloc or Free call. + int64 current_time_ = 0; + + // BufferInterval stores a buffer's size and time interval. + struct BufferInterval { + const BufferValue* buffer; + int64 size; + // Alloc time of the buffer. + int64 start; + // Free time of the buffer. + int64 end; + }; + absl::flat_hash_map buffer_intervals_; +}; + +// A heap algorithm that chooses the best results from other algorithms added to +// it. +class ChooseBestHeapAlgorithm : public HeapAlgorithm { + public: + ChooseBestHeapAlgorithm( + std::unique_ptr>> algorithms) + : algorithms_(std::move(*algorithms)) {} + ~ChooseBestHeapAlgorithm() override {} + + void Alloc(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Alloc(buffer, size); + } + } + + void Free(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Free(buffer, size); + } + } + + Result Finish() override; + + private: + std::vector> algorithms_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 957c4a68915934796a315f2443c90e571e942e75..e30e7667f3015bc7bfe67c65147a5016332780f7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -98,6 +98,124 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } +TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { + // HloModule SubcomputationAccounting + + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[4]{0} constant({1, 1, 1, 1}) + // ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0} + // %constant.1) + // } + + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} + // %reshape = f32[] reshape(f32[1]{0} %slice) + // %constant = f32[] constant(0) + // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // } + + // ENTRY %SubcomputationAccounting () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, + // 3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} + // %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1, + // 1}) %while = f32[4]{0} while(f32[4]{0} %constant.2), + // condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0} + // broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0} + // add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewVerifiedModule(); + const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // reshape(slice(param)) != 0 + // Needs 5 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* slice = + cond_builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1})); + HloInstruction* reshape = + cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); + HloInstruction* zero = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction* cond_comparison = + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + HloInstruction* subtract = + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 1, 1, 1}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {1})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + auto entry_computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + std::vector cond_vec = {cond_param, slice, reshape, zero, + cond_comparison}; + std::vector while_body_vec = {body_param, one_vector, + subtract}; + std::vector entry_comp_vec = {while_init, while_loop, bcast, + matrix, transpose, add}; + schedule.set_sequence(cond_computation, cond_vec); + schedule.set_sequence(body_computation, while_body_vec); + schedule.set_sequence(entry_computation, entry_comp_vec); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + absl::flat_hash_map memory_by_computation; + memory_by_computation[cond_computation] = 5; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; @@ -174,7 +292,7 @@ class HeapSimulatorTracker { // Construct the module sequence grouped by computation. HloSchedule schedule(module_.get()); - tensorflow::gtl::FlatMap reverse_position; + absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) @@ -1021,5 +1139,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) { EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset); } +class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {}; + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(0, result.heap_size); + EXPECT_EQ(0, result.chunk_map.size()); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { + // space + // ^ + // | +---a---+ + // | +-------+ + // | +---c---+ + // | +-------+ + // | | b | + // | +-------+ + // | +-------+ + // | | | + // | | d | + // | +-------+ + // -----------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 30); + heap.Alloc(buffer_c_, 20); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_a_, 10); + heap.Free(buffer_b_, 30); + heap.Free(buffer_c_, 20); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(100, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | | + // | | d | + // | +---a---+ +-------+ + // | + // | +-------+ + // | | | + // | | c | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 50); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 50); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(120, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | d | + // | +--a--+ +-------+ + // | +-------+ + // | | | + // | | c | + // | +-------+ + // | +-------+ + // | | | + // | | e | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 40); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 30); + heap.Alloc(buffer_e_, 50); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 40); + heap.Free(buffer_d_, 30); + heap.Free(buffer_e_, 50); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(140, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 93ec2c9438bf11b8119a947c4465926810129b7f..a0eb9e6ddcdb85737f85a4f4de36600648cd734b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 53 +// Next ID: 58 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -124,9 +124,13 @@ message HloInstructionProto { // The string representation of the infeed configuration. bytes infeed_config = 27; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a external target (eg, global symbol) to call, only present for + // kCustomCall. string custom_call_target = 28; + // Opaque string, only present for kCustomCall. + string custom_call_opaque = 53; + // Shape of outfeed request. xla.Shape outfeed_shape = 29; @@ -176,6 +180,17 @@ message HloInstructionProto { // Collective permute field. repeated SourceTarget source_target_pairs = 52; + + // Sharding for kDomain instructions. + xla.OpSharding domain_entry_sharding = 54; + xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated Shape operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -309,6 +324,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0986da65cbd3d550ecfa01212364518aba651d86..c3da12e273c77793647981f8653649155aac9483 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -119,7 +121,7 @@ class BufferValueMap { } // Return a set of all the values in the given buffer. - const tensorflow::gtl::FlatSet& GetValuesInBuffer( + const absl::flat_hash_set& GetValuesInBuffer( BufferNumber buffer_number) const { return buffers_.at(buffer_number); } @@ -142,7 +144,7 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet& old_value_set = + absl::flat_hash_set& old_value_set = buffers_.at(old_buffer_number); old_value_set.erase(&value); if (old_value_set.empty()) { @@ -290,13 +292,11 @@ class BufferValueMap { const HloDataflowAnalysis& dataflow_; // A map containing the set of values contained in each buffer. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffers_; // A map indicating which buffer each value is contained in. - tensorflow::gtl::FlatMap - value_to_buffer_number_; + absl::flat_hash_map value_to_buffer_number_; // The buffer number of the next buffer to be created. BufferNumber next_buffer_number_ = 0; @@ -352,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous( bool HloAliasAnalysis::InstructionBuffersAreDistinct( const HloInstruction* instruction) const { - tensorflow::gtl::FlatSet buffers_seen; + absl::flat_hash_set buffers_seen; for (const auto& pair : dataflow_analysis_->GetInstructionValueSet(instruction)) { const HloValueSet& value_set = pair.second; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index e345804537723f01e9ccb63e7d6ded1bd68f4196..372f99ff01c786a503e9fc2a1ba96fb4abf75b4c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -110,7 +111,7 @@ class HloAliasAnalysis { std::unique_ptr dataflow_analysis_; // A map indicating which buffer a value is contained in. - tensorflow::gtl::FlatMap value_to_buffer_; + absl::flat_hash_map value_to_buffer_; // A lazily constructed vector containing all HloBuffers sorted by // HloBuffer::Id. diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 6c11a073b74c61e44dfe81a32261ae78ae7b46fb..9c3aa0e64d119c2560f4955d0bcb492519fa52a2 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h index 658643b427a9625fac1166151a89cbd669f817d5..24910ca07bf7c991d31875704b5dd918ed04fe6f 100644 --- a/tensorflow/compiler/xla/service/hlo_clone_context.h +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -73,12 +73,12 @@ class HloCloneContext { return FindOrDie(computations_, old_computation); } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_instructions() const { return instructions_; } - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& cloned_computations() const { return computations_; } @@ -86,10 +86,8 @@ class HloCloneContext { private: HloModule* module_; string suffix_; - tensorflow::gtl::FlatMap - instructions_; - tensorflow::gtl::FlatMap - computations_; + absl::flat_hash_map instructions_; + absl::flat_hash_map computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8c6903d76628f87b01de044f1e49de367bf38110..c2041c466708fd8c88d34f14fbc0064905f594a9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -24,6 +24,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -39,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } -namespace { - -// Returns the new name for a fusion parameter when we change its number. -// -// Fusion parameters are named foo.param_1, bar.param_2, etc. We are -// renumbering the parameters, so replace the final number in the name with -// the updated value. -string RenameFusionParameter(const string& original_name, int64 new_param_no) { - const string param_underscore = ".param_"; - size_t index = original_name.rfind(param_underscore); - if (index == string::npos) { - return original_name; - } - string after_param = original_name.substr(index + param_underscore.size()); - int64 numeric_suffix; - if (absl::SimpleAtoi(after_param, &numeric_suffix)) { - return StrCat(original_name.substr(0, index + param_underscore.size()), - new_param_no); - } - return original_name; -} - -} // namespace - Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + param_no, param_instruction->shape(), StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() { if (removed > 0) { const int64 param_no = i - removed; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); - HloInstruction* new_instr = - AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + HloInstruction* new_instr = AddInstructionInternal( + HloInstruction::CreateParameter(param_no, param_instruction->shape(), + StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -272,18 +245,19 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { << "instruction " << instruction->name() << " has control successors and cannot be removed"; - TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); - auto inst_it = instruction_iterators_.at(instruction); - (*inst_it)->set_parent(nullptr); - instructions_.erase(inst_it); + auto inst_it = instruction_iterators_.find(instruction); + TF_RET_CHECK(inst_it != instruction_iterators_.end()); + (*inst_it->second)->set_parent(nullptr); + instructions_.erase(inst_it->second); + instruction_iterators_.erase(inst_it); return Status::OK(); } -void HloComputation::set_root_instruction( - HloInstruction* new_root_instruction) { +void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!IsFusionComputation()) { + if (!IsFusionComputation() && !accept_different_shape) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape() << " is incompatible with " @@ -304,10 +278,9 @@ void HloComputation::set_root_instruction( namespace { // Helper which builds a post order of the HLO call graph. -void ComputeComputationPostOrder( - HloComputation* computation, - tensorflow::gtl::FlatSet* visited, - std::vector* post_order) { +void ComputeComputationPostOrder(HloComputation* computation, + absl::flat_hash_set* visited, + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -324,7 +297,7 @@ void ComputeComputationPostOrder( void HloComputation::ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) const { + absl::flat_hash_map* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -421,7 +394,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + absl::flat_hash_map visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -442,7 +415,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector HloComputation::MakeEmbeddedComputationsList() const { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; std::vector post_order; // To avoid special handling of this computation, cast away const of @@ -532,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map) { - tensorflow::gtl::FlatMap instruction_map; - tensorflow::gtl::FlatMap to_proto_id; + const absl::flat_hash_map& computation_map) { + absl::flat_hash_map instruction_map; + absl::flat_hash_map to_proto_id; std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { @@ -562,6 +535,28 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); + TF_RETURN_IF_ERROR([&]() -> Status { + std::vector parameters_seen(parameter_count); + int parameters_seen_count = 0; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) + << "Invalid parameter number. Expected [0, " << parameter_count + << "), got " << param_no; + TF_RET_CHECK(!parameters_seen[param_no]) + << "Parameter number " << param_no + << " already allocated in this computation"; + parameters_seen[param_no] = true; + parameters_seen_count++; + } + } + TF_RET_CHECK(parameters_seen_count == parameter_count) + << "Not all parameters in range [0, " << parameter_count + << ") were referenced"; + return Status::OK(); + }()); + auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -916,13 +911,14 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - context, suffix); + /*extras=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context, const string& suffix) { + absl::Span extras, HloCloneContext* context, + const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = absl::make_unique(parent(), suffix); @@ -944,6 +940,9 @@ std::unique_ptr HloComputation::CloneWithReplacements( VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; std::vector postorder; + for (HloInstruction* instr : extras) { + postorder.push_back(instr); + } for (HloInstruction* instr : MakeInstructionPostOrder()) { if (HloInstruction* replacement = replace(instr)) { postorder.push_back(replacement); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 91c5234a6fde6698c5d600d667e3370d44134a50..d87ab4bda162a74421e8906e07cfcb97e2128fe4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" @@ -40,8 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -134,9 +134,11 @@ class HloComputation { Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction - // must have already been added to the computation and have the same shape as - // the result of the computation for non fusion computations. - void set_root_instruction(HloInstruction* new_root_instruction); + // must have already been added to the computation. In addition it must have + // the same shape as the result of the computation for non fusion + // computations, except if accept_different_shape is set to true. + void set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape = false); // Return the root instruction of the computation. The root instruction is the // instruction which produces the output of the computation. @@ -186,7 +188,7 @@ class HloComputation { // calls. static StatusOr> CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& computation_map); // Gets the instructions in this computation. // @@ -225,7 +227,7 @@ class HloComputation { void UpdateReachabilityThroughInstruction( const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instructions_.size(); } + int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this // computation. This includes all embedded computations called directly or @@ -331,10 +333,13 @@ class HloComputation { // // If replacements maps a key to nullptr, we remove that instruction from the // new computation. + // If additional instructions are used by instructions in replacement map, + // they must be passed in post-order in the extras span. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloCloneContext* context = nullptr, const string& suffix = "clone"); + absl::Span extras, HloCloneContext* context = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -409,14 +414,14 @@ class HloComputation { // cross-replica-sum the union of the dependencies for all participating // instructions. using ChannelDependencyMap = - tensorflow::gtl::FlatMap>; + absl::flat_hash_map>; ChannelDependencyMap ComputeChannelDependencies() const; enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) const; + absl::flat_hash_map* visited) const; string name_; int64 unique_id_; @@ -434,7 +439,7 @@ class HloComputation { // instruction pointer to location in the list for fast lookup. using InstructionList = std::list>; InstructionList instructions_; - std::unordered_map + absl::flat_hash_map instruction_iterators_; std::vector param_instructions_; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index f837816cea78d78bb3d605dd91e81cac39036268..4f898ce61c3f36e83e4b13130a404dbb4a2c36c6 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,6 +76,26 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Don't constant fold unless it's a net positive or the output is small. + if (ShapeUtil::IsArray(instruction->shape())) { + int64 elements_in_removed_operands = 0; + for (HloInstruction* operand : instruction->operands()) { + if (operand->user_count() == 1 && + ShapeUtil::IsArray(operand->shape())) { + elements_in_removed_operands += + ShapeUtil::ElementsIn(operand->shape()); + } + } + int64 elements_in_constant = + ShapeUtil::ElementsIn(instruction->shape()); + + static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + if (elements_in_constant > elements_in_removed_operands && + elements_in_constant > kMaximumConstantSizeElements) { + continue; + } + } + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. @@ -84,6 +104,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { << instruction->ToString(); continue; } + VLOG(4) << "Constant folded: " << instruction->ToString(); TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(std::move(result)))); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 4557983a9c0b0006cc2189c96a88478d469475c1..4a624cc7b8483aaa834634185a23195e437bd4e4 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -23,7 +23,7 @@ namespace xla { // A pass which performs constant folding in order to avoid unnecessary // computation on constants. -class HloConstantFolding : public HloPassInterface { +class HloConstantFolding : public HloModulePass { public: absl::string_view name() const override { return "constant_folding"; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 3e0def5d26a0033d954a776c1c32d6c35acfb505..e45f905f7152c37a9ab2b41d407310671310c2a3 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } +const char* const kConstantFoldLargePad = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + b = f32[] constant(42) + ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 + })"; + +TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kConstantFoldLargePad)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b76c50bb5b99cf4c9e6d4e04c240e8159acfc338..b2005d3c210d4ae7e3702cb9624c3ad98056984c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -201,6 +202,44 @@ StatusOr MakeMapHlo(absl::Span operands, HloInstruction::CreateMap(map_shape, operands, map_computation)); } +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module) { + DCHECK_NE(nullptr, module); + std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::iota(all_dims.begin(), all_dims.end(), 0); + + auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); + HloComputation* reduce_computation; + { + HloComputation::Builder b(operand->name() + ".reduce_sub_computation"); + auto lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + b.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); + reduce_computation = module->AddEmbeddedComputation(b.Build()); + } + + return operand->parent()->AddInstruction(HloInstruction::CreateReduce( + scalar_shape, operand, init_value, all_dims, reduce_computation)); +} + +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + HloComputation* computation = pred->parent(); + DCHECK_EQ(computation, on_true->parent()); + DCHECK_EQ(computation, on_false->parent()); + TF_ASSIGN_OR_RETURN(Shape select_shape, + ShapeInference::InferTernaryOpShape( + HloOpcode::kSelect, pred, on_true, on_false)); + return computation->AddInstruction(HloInstruction::CreateTernary( + select_shape, HloOpcode::kSelect, pred, on_true, on_false)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index b22058abb4dcbf17631f28e4eacf6c7f1da781d2..8e5ddbbd503a501bd493aec43a2ccd4db883ef0c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -107,6 +108,35 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, StatusOr MakeMapHlo(absl::Span operands, HloComputation* map_computation); +// Creates a Reduce HLO instruction and adds it to the computation containing +// the operand. This will create the sub-computation needed for the reduction in +// the given module. binary_opcode should represent a binary operation. +StatusOr MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module); + +// Creates a Select HLO instruction and adds it to the computation containing +// the predicate. The on_true and on_false instructions must also be contained +// in the same computation. +StatusOr MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false); + +// Creates an R1 Constant HLO instruction of the given PrimitiveType with the +// given values and adds it to the given computation. +template +StatusOr MakeR1ConstantHlo(HloComputation* computation, + PrimitiveType type, + absl::Span values) { + Literal literal = LiteralUtil::CreateR1(values); + if (literal.shape().element_type() != type) { + TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); + } + return computation->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); +} + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index b59c9ba3ed7990eb2a35abc83f87b25a1b1e7c60..e602107cbe64320a8e8e740168cb294ec6be9667 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -137,8 +137,8 @@ StatusOr HloCSE::Run(HloModule* module) { // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. - tensorflow::gtl::FlatSet + absl::flat_hash_set representatives(/*N=*/computation->instruction_count() + 1, &CseHash, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index a28c03599a8765da708f37b986010713654647cb..e4857fd3fdd9a329b013ac8215cb6d36d73c4b7d 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -25,7 +25,7 @@ namespace xla { // and identical instructions with the same operands are commoned. The pass // iterates over the instructions in topological order which enables the pass to // find arbitrarily large common expressions. -class HloCSE : public HloPassInterface { +class HloCSE : public HloModulePass { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 6a63681996bc57f4ef16b2405ffc8ce4f003e783..71122e73b1c78fe3cdbd86b9be073a55686444f1 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { - tensorflow::gtl::FlatSet visited; + absl::flat_hash_set visited; absl::InlinedVector stack; stack.push_back(inst); while (!stack.empty()) { @@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { void HloDataflowAnalysis::DeleteMarkedValues() { #ifndef NDEBUG // Verify that no marked-for-deletion values are in any of the value sets. - tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); + absl::flat_hash_set id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -355,23 +356,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } -bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { - CHECK_EQ(slice->opcode(), HloOpcode::kSlice); - if (!slice->IsInPlaceSlice()) { - return false; - } - // If this slice is lowered to an in-place version, then it forwards the - // operand value to the output. - const InstructionValueSet& operand_set = - GetInstructionValueSet(slice->operand(0)); - InstructionValueSet& slice_set = GetInstructionValueSet(slice); - if (operand_set != slice_set) { - slice_set = operand_set; - return true; - } - return false; -} - bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -640,8 +624,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); - case HloOpcode::kSlice: - return UpdateSliceValueSet(instruction); case HloOpcode::kDomain: return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: @@ -673,7 +655,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; - tensorflow::gtl::FlatSet workset; + absl::flat_hash_set workset; auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { worklist.push(instruction); @@ -813,11 +795,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; - case HloOpcode::kSlice: - if (!instruction->IsInPlaceSlice()) { - define_all_values(); - } - break; case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1071,6 +1048,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index e62c1c2ac81981e1f44f4c7e1479107979576e32..abac398c04fc4c418d8814a0097db4434bc1cd9c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -182,7 +182,6 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); - bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 510d6360a1cf94ef06d2ed919a57c7a825886834..909853106d57d181e85e3e4134b4039be2b176f5 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2283,6 +2283,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -2308,7 +2346,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 1fe69b1395753a612499e6e87bfc22f8ac8e767b..401204267282b294ca9f701e29e9edd9f0f35b98 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -33,7 +33,7 @@ namespace xla { // // This pass does not remove dead parameter instructions, as parameter // instructions cannot be deleted. -class HloDCE : public HloPassInterface { +class HloDCE : public HloModulePass { public: ~HloDCE() override {} absl::string_view name() const override { return "dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index d36631fc2f16902ed8f1f89f903027081f9b3801..c0bf1b9e16b52d81365db277abeb06defeb12d44 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -30,7 +30,7 @@ namespace xla { // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no // kDomain instruction will be placed. -class HloDomainIsolator : public HloPassInterface { +class HloDomainIsolator : public HloModulePass { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 113fd18eae70f0a581e2ab3e44544c47fcab3361..c6d02f9f67bb599e496d20fc2acf2e627ed54438 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -40,18 +42,19 @@ namespace xla { return std::move(domain_map); } -bool HloDomainMap::InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const { +bool HloDomainMap::InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const { int64 domain_id1 = GetDomainId(instruction1); int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } -int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } -int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainMetadataId( + const HloInstruction* instruction) const { return FindOrDie(domain_metadata_id_, instruction); } @@ -106,8 +109,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() { auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { return a->Matches(*b); }; - tensorflow::gtl::FlatMap + absl::flat_hash_map domain_metadata(1024, hash, equal); for (auto& domain : instruction_domains_) { @@ -198,7 +201,8 @@ StatusOr> HloDomainMap::CreateDomain( return std::move(domain); } -bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { +bool HloDomainMap::IsDomainInstruction( + const HloInstruction* instruction) const { if (instruction->opcode() != HloOpcode::kDomain) { return false; } @@ -216,7 +220,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set, + const absl::flat_hash_set& instruction_set, const InstructionOrderMap& instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 56b557d7cea424f63cd4891661ae446133ee5a37..bce7d1aa7cf1822ef1608674e7bf9483c628e4b5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -19,14 +19,14 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,27 +58,26 @@ class HloDomainMap { } // Checks whether two instructions are within the same domain. - bool InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const; + bool InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const; // Checks whether instruction is a kDomain instruction of the kind we are // currently processing. - bool IsDomainInstruction(HloInstruction* instruction) const; + bool IsDomainInstruction(const HloInstruction* instruction) const; // Retrieves the domain identifier of the instruction, or -1 in case // instruction is not found within any domain. - int64 GetDomainId(HloInstruction* instruction) const; + int64 GetDomainId(const HloInstruction* instruction) const; // Returns the unique id of the domain metadata for the domain the given // instruction belongs to. The given instruction must not be a kDomain // instruction since each domain instruction is associated with 2 domains. - int64 GetDomainMetadataId(HloInstruction* instruction) const; + int64 GetDomainMetadataId(const HloInstruction* instruction) const; private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. - using InstructionOrderMap = - tensorflow::gtl::FlatMap; + using InstructionOrderMap = absl::flat_hash_map; HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} @@ -111,7 +110,7 @@ class HloDomainMap { // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set, + const absl::flat_hash_set& instruction_set, const InstructionOrderMap& instructions_order); // Populates domain_metadata_id_ that maps each HloInstruction to the unique @@ -120,8 +119,8 @@ class HloDomainMap { string domain_kind_; std::vector> instruction_domains_; - tensorflow::gtl::FlatMap instruction_to_domain_; - tensorflow::gtl::FlatMap domain_metadata_id_; + absl::flat_hash_map instruction_to_domain_; + absl::flat_hash_map domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 302807f816e4ab626af419023e7740fd6bde795f..d3c83c15ae3be67a64f3dc4bcb0312ae9fbc33e4 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -42,7 +42,7 @@ class DomainMetadata { // operand/user pathways, without crossing a kDomain instruction of a given // kind. The reach_set can contain kDomain instructions of other kinds, if // two domains of different kind intersect each other. - tensorflow::gtl::FlatSet reach_set; + absl::flat_hash_set reach_set; // The same instructions in reach_set, but purged from kDomain instructions // and ordered according to their computation graph post-order, i.e. @@ -55,8 +55,8 @@ class DomainMetadata { // whose dataflow enters the reach set (domain), while the exit_domains // contains the set of kDomain instructions whose dataflow exit the reach // set. - tensorflow::gtl::FlatSet enter_domains; - tensorflow::gtl::FlatSet exit_domains; + absl::flat_hash_set enter_domains; + absl::flat_hash_set exit_domains; }; virtual ~DomainMetadata() = default; diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 97bc8ef604092acc849b55b09af8a24bf775529e..0fc30fb86c337a8bba5957d504caa7deeac9b86c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -26,7 +26,7 @@ namespace xla { // Removes all the kDomain instructions of a given kind from the input module, // and calls the normalizer to propagate the properties on the possibly new born // instructions. -class HloDomainRemover : public HloPassInterface { +class HloDomainRemover : public HloModulePass { public: // Creates a new HloDomainRemover object tasked at removing all the kDomain // instructions of a given kind. diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 81d6d69a8c59da2fc77cb2bab808602cd964fdaf..bea5cba38d018029c9805e1593fadad54460447e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -29,7 +29,7 @@ namespace xla { // Verifies that the domain instructions are consistent, and the each domain is // surrounded by the same metadata. -class HloDomainVerifier : public HloPassInterface { +class HloDomainVerifier : public HloModulePass { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 44ded2c2faf7c38d1e2f2aae577ddc07089bbb6a..4d2a942925288ba4c3977ffcd25b55746a555a5e 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -25,7 +25,7 @@ namespace xla { // inserting Convert ops. This allows a backend to support an element type while // only actually implementing the Convert op for that element type. This is // generally not the fastest approach, but it works. -class HloElementTypeConverter : public HloPassInterface { +class HloElementTypeConverter : public HloModulePass { public: // eliminate_type is the type to eliminate as the input or output of ops, // using Convert ops to replace it with replace_with_type. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index b1bd9b7ae9f2f01eac09636e75ce6173fb76fe1c..1be91302d16c346d26f370617ffd1fcd8a5c5ed6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -496,6 +497,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { return Status::OK(); } +Status HloEvaluator::HandleReal(HloInstruction* real) { + auto operand = real->operand(0); + switch (operand->shape().element_type()) { + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](bfloat16 elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex64 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](Eigen::half elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](float elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](double elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleImag(HloInstruction* imag) { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + Status HloEvaluator::HandleCompare(HloInstruction* compare) { HloOpcode opcode = compare->opcode(); auto lhs = compare->operand(0); @@ -1170,80 +1226,87 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " - << rank; TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; - // We need to sort and array of keys and an array of values, where the + // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the // array using the keys, and then go back to pair-of-arrays. VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - auto sort_r1 = [](const Literal& keys_literal, - const Literal& values_literal) { - const auto& keys_data = keys_literal.data(); - const auto& values_data = values_literal.data(); - - using kv_pair = std::pair; - std::vector key_value_vector; - CHECK_EQ(keys_data.size(), values_data.size()); - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); - } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - std::vector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - Literal result_keys_literal(keys_literal.shape()); - result_keys_literal.PopulateR1(absl::Span(result_keys)); - Literal result_values_literal(values_literal.shape()); - result_values_literal.PopulateR1( - absl::Span(result_values)); - return std::make_pair(std::move(result_keys_literal), - std::move(result_values_literal)); - }; - - Literal result_tuple; - if (rank == 1) { - auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = - LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); - int64 r1_length = keys_literal.shape().dimensions(1); - for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto keys_r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - TF_ASSIGN_OR_RETURN(auto values_r1_slice, - values_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); - TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first.Reshape({1, r1_length})); - TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values, {0, 0}, {row, 0}, {1, r1_length})); - } - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); + if (rank == 0) { + // Nothing to sort. + return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); } + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys_literal.shape(), zero_base, + AsInt64Slice(keys_literal.shape().dimensions()), increment, + [&](absl::Span indices) -> StatusOr { + // Extract a slice from the keys and values literals that correspond to + // exactly the row in dimension 'sort_dim'. + std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto keys_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& keys_data = keys_to_sort.data(); + TF_ASSIGN_OR_RETURN(auto values_to_sort, + values_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& values_data = values_to_sort.data(); + using kv_pair = std::pair; + std::vector key_value_vector; + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. + absl::InlinedVector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + Literal sorted_keys(ShapeUtil::MakeShape( + keys_literal.shape().element_type(), {sort_dim_elements})); + sorted_keys.PopulateR1(absl::Span(result_keys)); + Literal sorted_values(ShapeUtil::MakeShape( + values_literal.shape().element_type(), {sort_dim_elements})); + sorted_values.PopulateR1(absl::Span(result_values)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector start_indices(rank, 0); + TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, + sorted_keys.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys_reshaped, start_indices, indices, slice_dimensions)); + TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, + sorted_values.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + + Literal result_tuple; + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } @@ -1253,6 +1316,9 @@ StatusOr EvaluateSortCurried(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { + case PRED: + return EvaluateSortInternal(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal(sort, keys_literal, values_literal); @@ -1289,15 +1355,6 @@ StatusOr EvaluateSort(HloInstruction* sort, } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - const int64 sort_dim = sort->dimensions(0); - const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); - if (sort_dim != rank - 1) { - return Unimplemented( - "Trying to sort along dimension %d, which is not the last " - "dimension", - sort_dim); - } - if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { @@ -1324,7 +1381,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { "unsupported"); } } - return reduce->Visit(typed_visitors_.at(first_element_type).get()); + return reduce->Visit(typed_visitors_[first_element_type].get()); } } @@ -1336,6 +1393,12 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 21e676d671af08d1626ca6f157db63bf8d23ae0b..07f8d0aad4af0b07303b4e485b3630cc75bcb519 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -134,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. Status DefaultAction(HloInstruction* hlo) override { - return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); } Status Preprocess(HloInstruction* hlo) override; @@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; + Status HandleReal(HloInstruction* real) override; + + Status HandleImag(HloInstruction* imag) override; + Status HandleReduce(HloInstruction* reduce) override; // Returns the already-evaluated literal result for the instruction. @@ -206,8 +210,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // post-orderring. // Must be cleared for each evaluation. // Storing Literal in place require the container to have pointer stability so - // we cannot use FlatMap any more. - std::unordered_map evaluated_; + // we cannot use flat_hash_map any more. + absl::node_hash_map evaluated_; private: template @@ -237,12 +241,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; + std::unique_ptr typed_visitors_[PrimitiveType_ARRAYSIZE]; // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index c758bb2240a237f02d0f705c3f6f360da92f68c2..ce83126cb2cde93e29682a1d7cd44c2fbaec4fb6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -65,6 +65,20 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, return evaluator_->Evaluate(*module().entry_computation(), arg_literals); } + // Evaluate function that takes in a local module instead of using module_ + // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // removed, this should be the default Evaluate function. + Literal EvaluateWithModule( + HloModule* module, absl::Span arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. + auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(module).ValueOrDie(); + } + return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } + std::unique_ptr evaluator_; void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, @@ -1448,6 +1462,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); @@ -2554,6 +2620,114 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_NegativeIndices + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the negative indices. + Literal scatter_indices = LiteralUtil::CreateR1({-1, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { + const string hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the OOB indices. + Literal scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + // Given the update window size of 2,2 and the index of 0,2, the update window + // will be OOB. So, nothing should be updated. + Literal expected = operand.Clone(); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { @@ -2673,6 +2847,25 @@ ENTRY main (pred: pred[]) -> f32[]{ EXPECT_FALSE(statusor.status().ok()); } +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. + const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 7f090a52db33a9cfc83b67a07d613ce2fe5f7e9e..84fbbd3e0c3ddb704b8db601897f3b199dc99626 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include + #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" @@ -41,7 +43,9 @@ template using is_complex64_t = std::is_same; // It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template ::value || - std::is_same::value>::type* = nullptr> +template ::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template ::value>::type* = nullptr> +template ::value || + std::is_same::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast(a), static_cast(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -78,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. +// - HandleImag and HandleReal: where the resulting literal type is always float +// and the operand is always complex, or real in the case of HandleReal. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. // @@ -249,12 +262,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -265,11 +273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = result.Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -327,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor(floor); } - Status HandleImag(HloInstruction* imag) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], - ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { - return std::imag(elem_operand); - })); - return Status::OK(); - } - Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -682,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleReal(HloInstruction* real) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], - ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { - return std::real(elem_operand); - })); - return Status::OK(); - } - template ::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { @@ -1084,66 +1072,66 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { + // Find corresponding spatial dimension index for input (lhs). + int64 lhs_linear_spatial_index = 0; + int64 rhs_linear_spatial_index = 0; + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + lhs_linear_spatial_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + rhs_linear_spatial_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { const int64 iz = feature_group_index * input_feature_group_size + rhs_iz; - int64 lhs_linear_index = 0; + int64 lhs_linear_index = lhs_linear_spatial_index; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = 0; + int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - result_val += static_cast(lhs_literal_data[lhs_linear_index]) * static_cast(rhs_literal_data[rhs_linear_index]); @@ -1536,47 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data(); - - std::vector result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess(a, b); - }); - Literal result_literal(keys_literal.shape()); - result_literal.PopulateR1(absl::Span(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal result_literal(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result = sort_r1(r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector zero_base(rank, 0); + std::vector increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span indices) -> StatusOr { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. + std::vector limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data(); + + std::vector result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span(result_data)); + std::vector slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } @@ -2274,19 +2270,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // be 1. int64 update_dim_size = update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); - // Clamp the scatter index so that the scatter region fits in the - // operand. input_scatter_index_clamped[i] = - // clamp(input_scatter_index[i], 0, - // operand_shape.dimensions(i) - - // update_dim_size); - input_scatter_index_clamped[i] = - std::min(operand_shape.dimensions(i) - update_dim_size, - std::max(0LL, input_scatter_index[i])); + // If any part of the update region is out-of-bounds, then do not + // perform any update on the input. + if ((input_scatter_index[i] < 0) || + (input_scatter_index[i] > + operand_shape.dimensions(i) - update_dim_size)) { + return true; + } } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; - DCHECK_GE(input_index[i], 0); - DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + input_index[i] = input_scatter_index[i] + input_window_index[i]; } auto result_value_literal = @@ -2350,8 +2343,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); + Literal result(shape); TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); @@ -2621,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. + base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index de3d7a167752f0de790585e50874dd6d2904bd37..ce4cad42355ec5881f2ae14f4dd52a0588d51cf7 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -90,8 +90,9 @@ std::unique_ptr CreateHloProfilePrinterData( HloInstructionInfo* instruction_info = computation_info->add_instruction_infos(); instruction_info->set_long_name(hlo->ToString()); - instruction_info->set_short_name( - hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_short_name(hlo->ToString( + HloPrintOptions().set_compact_operands(true).set_print_operand_names( + false))); instruction_info->set_category(hlo->ToCategory()); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); instruction_info->set_transcendental_count( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 287ba84b3b24d3ec6dc21d157205ebc6a987c7d7..13a74fd8a115c5dc9a9518b226dfee4445cc7180 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1110,7 +1110,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { instr->metadata().source_line())); } - return StrJoin(lines, "
"); + return StrJoin(lines, "\n"); } string HloDotDumper::GetInstructionNodeBackendConfig( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e905f2983a43189eeb06824cf3078c235ab07925..c317e9e3b48d088592d6c5c0b914004e1ccb3d63 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" @@ -37,14 +39,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" @@ -59,8 +60,8 @@ using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map) { + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -80,6 +81,20 @@ StatusOr> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + + TF_RET_CHECK(std::all_of( + proto.operand_ids().begin(), proto.operand_ids().end(), + [&instruction_map](int64 id) { return instruction_map.contains(id); })) + << proto.name() << " instruction contains invalid operand id(s)"; + + TF_RET_CHECK(std::all_of( + proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), + [&computation_map](int64 id) { return computation_map.contains(id); })) + << proto.name() << " instruction references invalid computation id(s)"; + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: @@ -180,17 +195,16 @@ StatusOr> HloInstruction::CreateFromProto( } break; case HloOpcode::kSort: { - TF_RET_CHECK(proto.operand_ids_size() == 1 || - proto.operand_ids_size() == 2) - << "Sort instruction should have 1 or 2 operands but has " + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "Sort instruction should have at least 1 operand but has " << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; - HloInstruction* keys = operands(0); - HloInstruction* values = - proto.operand_ids_size() == 2 ? operands(1) : nullptr; - instruction = - CreateSort(proto.shape(), proto.dimensions(0), keys, values); + auto sort_operands = all_operands(); + HloInstruction* keys = sort_operands[0]; + instruction = CreateSort( + proto.shape(), proto.dimensions(0), keys, + absl::Span(sort_operands).subspan(1)); break; } case HloOpcode::kTranspose: @@ -266,7 +280,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + auto* fused_computation = + tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), @@ -289,6 +304,9 @@ StatusOr> HloInstruction::CreateFromProto( proto.tuple_index()); break; case HloOpcode::kReducePrecision: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); @@ -296,12 +314,18 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), operands(1), proto.outfeed_config()); break; @@ -331,6 +355,9 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { @@ -378,8 +405,22 @@ StatusOr> HloInstruction::CreateFromProto( operands(1), operands(2), computations(1)); break; case HloOpcode::kCustomCall: - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target()); + if (proto.constrain_layout()) { + // A proto RepeatedPtrField cannot be converted to a Span (it is a + // vector of pointers essentially) so create a vector of shapes to pass + // in. + std::vector operand_shapes; + for (const Shape& shape : proto.operand_shapes_with_layout()) { + operand_shapes.push_back(shape); + } + instruction = CreateCustomCall( + proto.shape(), all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); + } else { + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target(), + proto.custom_call_opaque()); + } if (proto.has_window()) { static_cast(instruction.get()) ->set_window(proto.window()); @@ -446,8 +487,8 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kIota: - TF_RET_CHECK(proto.dimensions_size() <= 1) - << "Iota instruction should have at most 1 dimension but sees " + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; @@ -465,31 +506,34 @@ StatusOr> HloInstruction::CreateFromProto( proto.dot_dimension_numbers(), precision_config); break; } - case HloOpcode::kDomain: + case HloOpcode::kDomain: { TF_RET_CHECK(proto.operand_ids_size() == 1) << "Domain instruction should have 1 operands but sees " << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_domain_entry_sharding()) + << "Domain instruction must domain_entry_sharding"; + TF_RET_CHECK(proto.has_domain_exit_sharding()) + << "Domain instruction must domain_exit_sharding"; + TF_ASSIGN_OR_RETURN( + HloSharding entry_hlo_sharding, + HloSharding::FromProto(proto.domain_entry_sharding())); + TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding, + HloSharding::FromProto(proto.domain_exit_sharding())); instruction = absl::make_unique( - proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, - /*user_side_metadata=*/nullptr); + proto.shape(), operands(0), + absl::make_unique( + std::make_shared(entry_hlo_sharding)), + absl::make_unique( + std::make_shared(exit_hlo_sharding))); break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; instruction->AppendOperand(instruction_map.at(operand_id)); } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } if (instruction->opcode() != HloOpcode::kFusion) { for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; instruction->called_computations_.push_back( computation_map.at(computation_id)); } @@ -501,6 +545,13 @@ StatusOr> HloInstruction::CreateFromProto( } } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); @@ -1026,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) { + absl::Span values) { return absl::make_unique(shape, dimension, keys, values); } @@ -1108,9 +1159,18 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target) { - return absl::make_unique(shape, operands, - custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque) { + return absl::make_unique( + shape, operands, custom_call_target, opaque); +} + +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque) { + return absl::make_unique( + shape, operands, custom_call_target, opaque, operand_shapes_with_layout); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1431,7 +1491,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; - tensorflow::gtl::FlatSet seen; + absl::flat_hash_set seen; for (HloInstruction* operand : operands()) { if (seen.insert(operand).second) { unique.push_back(operand); @@ -2005,7 +2065,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( options.is_in_nested_computation()) { str.push_back(PrintName( canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { + } else if (options.print_operand_names()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); @@ -2660,14 +2720,14 @@ class HloInstruction::FusionReusesParamElements { // the value of this parameter, which would save stack space but not allow us // to finish early if we find a reuse. static UseKind Compute(int64 i, const HloInstruction& hlo) { - tensorflow::gtl::FlatMap memoization_cache; + absl::flat_hash_map memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); } private: static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, - tensorflow::gtl::FlatMap* cache) { + absl::flat_hash_map* cache) { if (auto hlo_param = DynCast(&hlo)) { if (hlo_param->parameter_number() == i) { return UseKind::kUse; @@ -2910,6 +2970,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } +bool HloPtrComparator::operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } + auto lhs_module = lhs->GetModule(); + auto rhs_module = rhs->GetModule(); + CHECK((lhs_module == nullptr && rhs_module == nullptr) || + (lhs_module != nullptr && rhs_module != nullptr)); + if (lhs_module != nullptr && + lhs_module->unique_id() != rhs_module->unique_id()) { + return lhs_module->unique_id() < rhs_module->unique_id(); + } + return lhs->unique_id() < rhs->unique_id(); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3027,10 +3107,6 @@ const std::vector& HloInstruction::slice_strides() const { return Cast(this)->slice_strides(); } -bool HloInstruction::IsInPlaceSlice() const { - return Cast(this)->IsInPlaceSlice(); -} - const Literal& HloInstruction::literal() const { return Cast(this)->literal(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 4f6cac1396c16beb5cebf909032dead711d77a61..93ff04b1e4f9fc5862d1f3e8b5ebb769241761b1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -80,6 +80,7 @@ class HloPrintOptions { print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), + print_operand_names_(true), print_program_shape_(true), print_percent_(true), print_control_dependencies_(true), @@ -107,6 +108,7 @@ class HloPrintOptions { .set_print_metadata(false) .set_print_backend_config(false) .set_compact_operands(true) + .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) @@ -144,6 +146,12 @@ class HloPrintOptions { return *this; } + // If true, the operand names will be printed. + HloPrintOptions& set_print_operand_names(bool value) { + print_operand_names_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -162,8 +170,8 @@ class HloPrintOptions { return *this; } - // If true, only a part of operands will be printed out, and their names will - // be omitted (note that in this case the text will not be parsable). + // If true, only a part of operands will be printed out (note that in this + // case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { compact_operands_ = value; return *this; @@ -197,6 +205,7 @@ class HloPrintOptions { bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } + bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } bool print_control_dependencies() const { @@ -215,6 +224,7 @@ class HloPrintOptions { bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; + bool print_operand_names_; bool print_program_shape_; bool print_percent_; bool print_control_dependencies_; @@ -247,7 +257,7 @@ class CanonicalNameMap { private: int64 index; - tensorflow::gtl::FlatMap canonical_name_map; + absl::flat_hash_map canonical_name_map; }; // HLO instructions are the atomic unit of the high-level compiler's IR. @@ -350,8 +360,8 @@ class HloInstruction { // calls. static StatusOr> CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map); + const absl::flat_hash_map& instruction_map, + const absl::flat_hash_map& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -660,10 +670,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and an optional values operand. + // Creates a sort op, with a keys operand, and optional values operands. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For @@ -718,10 +728,21 @@ class HloInstruction { HloComputation* computation); // Creates a custom call instruction that applies the given custom call target - // to the given operands. "shape" is the resultant shape. + // to the given operands. "opaque" can be an arbitrary string with a + // backend-specific interpretation. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque = ""); + + // Overload which constrains the layouts of the operand and result. 'shape' + // and 'operand_shapes_with_layout' must have layouts. + // 'operand_shapes_with_layout' must have a compatible element for each + // operand. + static std::unique_ptr CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque = ""); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -1319,9 +1340,6 @@ class HloInstruction { int64 slice_strides(int64 dimension) const; const std::vector& slice_strides() const; - // Delegates to HloSliceInstruction::IsInPlaceSlice. - bool IsInPlaceSlice() const; - // Returns the literal associated with this instruction. const Literal& literal() const; @@ -1616,6 +1634,10 @@ class HloInstruction { InstructionVector operands_; // The set of control predecessors of this instruction. + // Note that the order of the instructions in the vector influences the order + // computed in HloComputation::ComputeInstructionPostOrder, which may + // influence the result of the compilation by changing the scheduling. We are + // not sure if it matters. std::vector control_predecessors_; // The users of this instruction. Users are HLOs where this instruction is an @@ -1689,21 +1711,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. Exception: null pointer values compare less than non-null. -// -// Note that this cannot be used for HLO instructions across multiple modules -// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, - const HloInstruction* const& rhs) const { - if (rhs == nullptr) { - // Nothing compares less than nullptr. - return false; - } - if (lhs == nullptr) { - return true; - } - return lhs->unique_id() < rhs->unique_id(); - } + const HloInstruction* const& rhs) const; }; template diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c1b7c3832b44b5d65b715dffa5211a5c92e17953..d93351fe0435b5f29035dc4ea0621a8c576bfd5a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) { auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); - EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); + EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); EXPECT_EQ(0, parameter->operand_count()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e92882c22a6ef1dd43440d3c94c7d233c9a4fb5d..179ace2cdb76051fecdeb7e0cbdcd808bf9fee25 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" @@ -27,8 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_channel_id(channel_id_); + proto.set_is_host_transfer(is_host_transfer_); return proto; } @@ -598,11 +600,11 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) + absl::Span values) : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { AppendOperand(keys); - if (values) { - AppendOperand(values); + for (auto* value : values) { + AppendOperand(value); } } @@ -631,9 +633,8 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; - HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; return absl::make_unique(shape, dimensions(0), keys, - values); + new_operands.subspan(1)); } HloTransposeInstruction::HloTransposeInstruction( @@ -641,14 +642,6 @@ HloTransposeInstruction::HloTransposeInstruction( absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -1042,7 +1035,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( const int64 param_no = operand_count(); // Name the parameter after the instruction it represents in the outer // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); + // string param_name = StrCat(new_operand->name(), ".param_", param_no); + string param_name = StrCat("param_", param_no); HloInstruction* fused_parameter = fused_instructions_computation()->AddParameter( HloInstruction::CreateParameter(param_no, new_operand->shape(), @@ -1098,7 +1092,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( // Note that we add the unfused instructions to this->parent_ computation. // This is necessary because the unique_id needs for an instruction and // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; + absl::flat_hash_map old_to_new; std::vector unfused_instructions; auto computation_to_merge = instruction_to_merge->fused_instructions_computation(); @@ -1391,7 +1385,7 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( } Status HloFusionInstruction::DeduplicateFusionOperands() { - tensorflow::gtl::FlatMap operand_indices; + absl::flat_hash_map operand_indices; std::vector operands_to_remove; for (int i = 0; i < operand_count(); ++i) { auto emplace_result = operand_indices.emplace(operand(i), i); @@ -1488,7 +1482,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl( HloGetTupleElementInstruction::HloGetTupleElementInstruction( const Shape& shape, HloInstruction* operand, int64 index) : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); AppendOperand(operand); } @@ -1610,9 +1603,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); AppendOperand(token_operand); } @@ -1830,10 +1820,28 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span operands, - absl::string_view custom_call_target) + absl::string_view custom_call_target, absl::string_view opaque) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), - feature_group_count_(1) { + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(false) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque, + absl::Span operand_shapes_with_layout) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(true), + operand_shapes_with_layout_(operand_shapes_with_layout.begin(), + operand_shapes_with_layout.end()) { for (auto operand : operands) { AppendOperand(operand); } @@ -1849,7 +1857,14 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + if (layout_constrained()) { + proto.set_constrain_layout(true); + for (const Shape& shape : operand_shapes_with_layout_) { + *proto.add_operand_shapes_with_layout() = shape; + } + } return proto; } @@ -1872,6 +1887,19 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( // an HloComputation. extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + // If the opaque string becomes enormous we may want to reconsider printing + // this inline and consider other options. + if (!opaque_.empty()) { + extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); + } + if (layout_constrained()) { + std::vector shape_strings; + for (const Shape& shape : operand_shapes_with_layout_) { + shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape)); + } + extra.push_back(StrCat("operand_layout_constraints={", + StrJoin(shape_strings, ", "), "}")); + } return extra; } @@ -1897,7 +1925,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } - return custom_call_target_ == casted_other.custom_call_target_; + return custom_call_target_ == casted_other.custom_call_target_ && + opaque_ == casted_other.opaque_; } std::unique_ptr @@ -1905,7 +1934,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { auto cloned = absl::make_unique( - shape, new_operands, custom_call_target()); + shape, new_operands, custom_call_target(), opaque()); if (window_ != nullptr) { cloned->set_window(*window_); } @@ -2301,4 +2330,23 @@ std::unique_ptr HloDomainInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); } + +HloInstructionProto HloDomainInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + auto operand_side_sharding = + dynamic_cast(operand_side_metadata_.get()); + if (operand_side_sharding) { + *proto.mutable_domain_entry_sharding() = + operand_side_sharding->sharding()->ToProto(); + } + + auto user_side_sharding = + dynamic_cast(user_side_metadata_.get()); + if (user_side_sharding) { + *proto.mutable_domain_exit_sharding() = + user_side_sharding->sharding()->ToProto(); + } + + return proto; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 2d7bc83855e761ed313d831a1252a54130910bbe..3a0b7490dc70b010fc360425cf1a71103c370eb8 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -418,7 +418,7 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -546,17 +546,6 @@ class HloSliceInstruction : public HloInstruction { } const std::vector& slice_strides() const { return slice_strides_; } - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -573,9 +562,6 @@ class HloSliceInstruction : public HloInstruction { std::vector slice_starts_; std::vector slice_limits_; std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; }; class HloConstantInstruction : public HloInstruction { @@ -910,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. @@ -1068,9 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction { class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction(const Shape& shape, - absl::Span operands, - absl::string_view custom_call_target); + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1090,6 +1085,7 @@ class HloCustomCallInstruction : public HloInstruction { convolution_dimension_numbers_ = absl::make_unique(dnums); } + const string& opaque() const { return opaque_; } const string& custom_call_target() const { return custom_call_target_; } void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; @@ -1098,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. + bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. + const std::vector& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1109,14 +1115,21 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a global symbol to call. string custom_call_target_; + // Opaque string interpreted by the backend. + string opaque_; // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { @@ -1337,6 +1350,9 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata); + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + // Retrieves the operand side metadata of a kDomain instruction. const DomainMetadata& operand_side_metadata() const { return *operand_side_metadata_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 3a1dd471c626ae9497cfcca62c30736bcdbb2b38..5bf055f3c012fef687cdc275d62efdf2d4cd5e5c 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers( } } +// Makes sure that if a live instruction is within a computation used in control +// flow operations, we mark live even other related instructions. +void PropagateLivenessThroughControlFlow( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + HloInstruction* caller = callsite.instruction(); + if (caller->opcode() == HloOpcode::kWhile) { + // If a live instruction is within the %while body or condition + // computation, mark the predicate value returned by the condition + // computation live as well. + MarkLiveAtIndex(caller->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); + } else if (caller->opcode() == HloOpcode::kConditional) { + // If a live instruction is within the true or false branches of a + // conditional, we mark the predicate operand live as well. + MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist, + workset); + } + } + } +} + } // namespace HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) @@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() { } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kWhile && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kWhile) { PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kParameter) { PropagateLivenessToParameterCallers(instruction, &live_index_map_, &worklist, &workset, call_graph_.get()); @@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() { MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); } } + PropagateLivenessThroughControlFlow(instruction, &live_index_map_, + &worklist, &workset, call_graph_.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 01b625c29ca2823b2a2490b30a9d4d5128b4c22e..e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); } +TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + InnerWhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + InnerWhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + OuterWhileCondition { + cond_param.2 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 + constant.5 = s32[] constant(5) + ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + } + OuterWhileBody { + body_param.2 = (s32[]) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0 + constant.6 = s32[] constant(0) + tuple.2 = (s32[]) tuple(constant.6) + inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition, + body=InnerWhileBody + constant.7 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.8, constant.7) + ROOT rtuple = (s32[]) tuple(add.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=OuterWhileCondition, + body=OuterWhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index c7ec88d450712b0831971139f165934ef5524845..5cee865b7ad34eded1743d9d5455bb40febf6182 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -74,7 +76,7 @@ class ListScheduler { const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { ListScheduler scheduler(computation, points_to_analysis, size_function, memory_by_computation); @@ -99,7 +101,7 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), @@ -110,7 +112,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet instr_uses; + absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( [&](const ShapeIndex& /*index*/, @@ -193,13 +195,15 @@ class ListScheduler { return entry; } - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). + // Returns the number of bytes freed *after* the HLO instruction finishes. + // The current List algorithm only considers two states for an instruction: + // right before it runs, and after it finishes. We don't represent memory + // usage during the execution of an instruction. But if the instruction calls + // subcomputations, they are only live during the instruction's execution. + // We end up counting the memory used by subcomputations as memory "defined" + // by the instruction. This is not entirely accurate, but it is more accurate + // than not taking subcomputations into account at all. In the future, we may + // improve accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -221,7 +225,18 @@ class ListScheduler { } } } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + int64 bytes_defined; + if (max_subcomputation_bytes > 0 && + (entry.instruction->opcode() == HloOpcode::kWhile || + entry.instruction->opcode() == HloOpcode::kCall || + entry.instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + bytes_defined = max_subcomputation_bytes; + } else { + bytes_defined = entry.bytes_defined + max_subcomputation_bytes; + } + return freed_bytes - bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -234,8 +249,7 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - tensorflow::gtl::FlatMap - unscheduled_pred_count; + absl::flat_hash_map unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. @@ -251,8 +265,8 @@ class ListScheduler { std::multimap ready_queue; // Map of ready instructions to their iterators in ready_queue. - tensorflow::gtl::FlatMap::iterator> + absl::flat_hash_map::iterator> ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { @@ -262,9 +276,8 @@ class ListScheduler { }; for (auto* instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction) == 0) { + if (instruction->operands().empty() && + instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); } } @@ -347,21 +360,19 @@ class ListScheduler { // Computations are analyzed in post-order. When scheduling an instruction // that includes subcomputations, such as a while loop, we use this map to // look up the memory needed by subcomputations. - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map, and that the map - // entries are std::pair's. - std::unordered_map unscheduled_use_count_; + // LogicalBuffer. + absl::flat_hash_map unscheduled_use_count_; // Set of instructions which have been scheduled. - tensorflow::gtl::FlatSet scheduled_instructions_; + absl::flat_hash_set scheduled_instructions_; }; int64 SumLogicalBufferSizes( @@ -379,7 +390,7 @@ StatusOr ScheduleComputationHelper( const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { VLOG(2) << "Computation: " << computation.name(); if (algorithm) { @@ -396,13 +407,13 @@ StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); - tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; + int64 total_hlos = computation.parent()->instruction_count(); + absl::flat_hash_map extra_users; + absl::flat_hash_map total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; @@ -419,7 +430,7 @@ StatusOr DFSMemoryScheduler( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet unique_operands( + absl::flat_hash_set unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; @@ -467,7 +478,7 @@ StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { return ListScheduler::Run(computation, points_to_analysis, size_function, memory_by_computation); @@ -477,7 +488,7 @@ StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { return HloInstructionSequence(computation.MakeInstructionPostOrder()); } @@ -486,7 +497,7 @@ StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -549,7 +560,7 @@ StatusOr ScheduleModule( HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - tensorflow::gtl::FlatMap memory_by_computation; + absl::flat_hash_map memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, @@ -577,7 +588,7 @@ StatusOr ScheduleComputation( CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); - tensorflow::gtl::FlatMap empty_map; + absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 5e02868ebadaf06458f81e4f10ac04f882421ec8..a4c1d3db8170a1725043def576f913e09b352e5d 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -37,7 +38,7 @@ namespace xla { typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, - const tensorflow::gtl::FlatMap&)> + const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler @@ -45,7 +46,7 @@ StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // DFS-order scheduler @@ -53,7 +54,7 @@ StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // Naive Post Order scheduler @@ -61,7 +62,7 @@ StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler @@ -71,7 +72,7 @@ StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap& + const absl::flat_hash_map& memory_by_computation); // Returns an HloSchedule which seeks to minimize the memory required for @@ -90,7 +91,7 @@ StatusOr ScheduleComputation( // A pass which schedules the HLO instructions in a module. The HloModule's // schedule field is set to the resulting HloSchedule using // HloModule::set_schedule. -class HloMemoryScheduler : public HloPassInterface { +class HloMemoryScheduler : public HloModulePass { public: // size_function is the function returning the number of bytes required for a // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not @@ -109,7 +110,7 @@ class HloMemoryScheduler : public HloPassInterface { // A trivial pass which clears the schedule currently set on the // HloModule. After this pass runs HloModudle::has_schedule will return false. -class HloDescheduler : public HloPassInterface { +class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; ~HloDescheduler() override = default; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 1b9e9bfc77c3ba91e5b878f4aa42d26d8267a49a..214119fba881c4411a262cd4227b5cc49cef0d14 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -146,126 +147,6 @@ ENTRY root { instructions_by_name.at("e"))); } -TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { - // %WhileCond (cond_param: f32[4]) -> pred[] { - // %cond_param = f32[4]{0} parameter(0) - // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) - // ROOT %not-equal-to = pred[] not-equal-to( - // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) - // } - // %WhileBody (body_param: f32[4]) -> f32[4] { - // %body_param = f32[4]{0} parameter(0) - // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // ROOT %subtract = f32[4]{0} subtract( - // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) - // } - // %ListAccountsForSubcomputations () -> f32[2,4] { - // %constant.3 = f32[2,4]{1,0} constant( - // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) - // %transpose = f32[2,4]{1,0} transpose( - // f32[2,4]{1,0} %constant.3), dimensions={0,1} - // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), - // condition=%WhileCond, - // body=%WhileBody - // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} - // ROOT %add = f32[2,4]{1,0} add( - // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) - // } - - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - // transpose(matrix) + bcast(while) - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - HloInstruction* while_loop = - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - // Creates 32 bytes and frees 16 - HloInstruction* bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); - - HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); - // Creates 32 bytes - HloInstruction* transpose = builder.AddInstruction( - HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); - - // Creates 32 bytes and frees 64 - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - schedule.sequence(entry_computation).size()); - SequentialHloOrdering ordering(schedule); - // This schedule is an example of List's greedy heuristics being suboptimal. - // The while_loop is more expensive than transpose, so it would have been - // better to schedule it first, instead of during the busy time. - EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); - EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); - - tensorflow::gtl::FlatMap memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The output buffer is aliased, - // so we don't double count. - EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto builder = HloComputation::Builder(TestName()); const auto TUPLE_SIZE = 1; @@ -409,7 +290,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { EXPECT_EQ(module->entry_computation()->instruction_count(), schedule.sequence(module->entry_computation()).size()); - tensorflow::gtl::FlatMap memory_by_computation; + absl::flat_hash_map memory_by_computation; memory_by_computation[cond_computation] = 17; memory_by_computation[body_computation] = 16; std::unique_ptr points_to_analysis = diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index b3949f3a6d7176950c61cafb0830d1175f17758d..93e04eb3db47ba3dadfbd412733997b92c07da92 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -144,7 +146,8 @@ void HloModule::ReplaceComputations( case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: { + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: { HloComputation* new_arg = tensorflow::gtl::FindWithDefault( replacements, instruction->to_apply(), nullptr); if (new_arg != nullptr) { @@ -285,8 +288,8 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - tensorflow::gtl::FlatMap computation_map; - tensorflow::gtl::FlatMap to_proto_id; + absl::flat_hash_map computation_map; + absl::flat_hash_map to_proto_id; std::vector> computations; HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { @@ -327,10 +330,10 @@ StatusOr> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. - tensorflow::gtl::FlatSet computation_names; - tensorflow::gtl::FlatSet instruction_names; - tensorflow::gtl::FlatSet computation_ids; - tensorflow::gtl::FlatSet instruction_ids; + absl::flat_hash_set computation_names; + absl::flat_hash_set instruction_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3bc2d13781aa72738d695e37a02983ee82c6037d..735804e827afd77e2b7f2a4a7d490ee6f5ee7b4f 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -63,6 +63,7 @@ class HloModule { // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. explicit HloModule(const string& name, const HloModuleConfig& config); + virtual ~HloModule() {} // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -87,6 +88,7 @@ class HloModule { const std::unordered_map& replacements); const string& name() const { return name_; } + void set_name(string name) { name_ = std::move(name); } // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; @@ -255,7 +257,7 @@ class HloModule { std::unique_ptr computation, bool is_entry, bool uniquify_identifiers); - const string name_; + string name_; HloModuleConfig config_; HloComputation* entry_computation_ = nullptr; std::vector> computations_; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index f7be5cae2239e81d9aa1f5fb811a37c6086b028f..31d26cc51e8217234526bbfeb83510aadf2c27b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -50,9 +50,7 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_root = while_body_comp->root_instruction(); if (!ShapeUtil::IsTuple(xla_while->shape()) || - while_body_root->opcode() != HloOpcode::kTuple || - while_body_comp->HasSideEffect() || - xla_while->while_condition()->HasSideEffect()) { + while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 12ca2340a6ccaa50780e81168c755c1fec3aa1be..d472211d2af6e4b583d3815146ba8cee5c8e7495 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -28,7 +28,7 @@ namespace xla { // Sweeps through live instructions which cross computation boundaries (kWhile), // and removes code at dead shape indices. // -class HloModuleDCE : public HloPassInterface { +class HloModuleDCE : public HloModulePass { public: ~HloModuleDCE() override {} absl::string_view name() const override { return "hlo-module-dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index d025edbb9c4f5484458a6a96328a0ee5720b17f7..bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -372,26 +372,64 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { auto module = ParseHloString(R"( HloModule OutfeedLoop WhileBody { - loop_var.1 = (s32[]) parameter(0) + body_param = (s32[]) parameter(0) token = token[] after-all() constant.2 = s32[] constant(2) outfeed_tuple = (s32[]) outfeed(constant.2, token) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) ROOT tuple = (s32[]) tuple(add) } WhileCondition { - loop_var.2 = (s32[]) parameter(0) - get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 constant.2 = s32[] constant(10) ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) } ENTRY SimpleLoop { constant.3 = s32[] constant(0) tuple.1 = (s32[]) tuple(constant.3) - ROOT while = (s32[]) while(tuple.1), condition=WhileCondition, + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. +TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 })") .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9b56ef4643f2ca88e56456ae6c990161adb5085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(absl::string_view name, + std::unique_ptr module) + : name_(name) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector> HloModuleGroup::ConsumeModules() { + std::vector> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. + modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing a ordered set of HLO module built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + HloModuleGroup(absl::string_view name, std::unique_ptr module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span> modules); + + // Returns the modules contained in the group. + const std::vector& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of vector of modules in the group. + void push_back(std::unique_ptr module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. + std::vector> ConsumeModules(); + + string name() const { return name_; } + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs); + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 9c01862a4b7024826c3f701b795819abe945d07f..b4aac4c8076cb69647d42c6243bc969d06d0709e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { } /* static */ StatusOr> -HloModuleGroupMetadata::Build(const std::vector& modules) { +HloModuleGroupMetadata::Build(absl::Span modules) { auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); @@ -392,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); - companion_set->insert(instruction1); - companion_set->insert(instruction2); + companion_set->push_back(instruction1); + companion_set->push_back(instruction2); companion_set_index_[instruction1] = companion_sets_.size() - 1; companion_set_index_[instruction2] = companion_sets_.size() - 1; } else if (!ContainsKey(companion_set_index_, instruction1)) { - companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_sets_[companion_set_index_[instruction2]]->push_back( + instruction1); companion_set_index_[instruction1] = companion_set_index_[instruction2]; } else if (!ContainsKey(companion_set_index_, instruction2)) { - companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_sets_[companion_set_index_[instruction1]]->push_back( + instruction2); companion_set_index_[instruction2] = companion_set_index_[instruction1]; } else if (companion_set_index_[instruction1] != companion_set_index_[instruction2]) { - companion_sets_[companion_set_index_[instruction1]]->insert( - Companions(instruction2).begin(), Companions(instruction2).end()); + // At any point while building the companion sets, each instruction belongs + // to at most 1 companion set, so the union of two companion sets is + // concatenating two disjoint sets. + absl::c_copy(Companions(instruction2), + std::back_inserter( + *companion_sets_[companion_set_index_[instruction1]])); int64 index_to_remove = companion_set_index_[instruction2]; for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 768b0c7eb3695715de5cef7dad1ed5a110561605..928df0f5a7444ad877961a5de970c752e1d024da 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -102,14 +102,14 @@ class HloModuleGroupMetadata { HloInstruction* recv_done = nullptr; }; - explicit HloModuleGroupMetadata(const std::vector& modules) - : modules_(modules) {} + explicit HloModuleGroupMetadata(absl::Span modules) + : modules_(modules.begin(), modules.end()) {} ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. static StatusOr> Build( - const std::vector& modules); + absl::Span modules); // Returns true if the instruction is one of the 4 channel instructions (Send, // Recv, SendDone, RecvDone). @@ -169,14 +169,14 @@ class HloModuleGroupMetadata { // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. - const std::unordered_set& Companions( + const std::vector& Companions( const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } // Returns the companion set at the given index. - const std::unordered_set& companion_set(int64 index) const { + const std::vector& companion_set(int64 index) const { CHECK_LT(index, companion_sets_.size()); return *companion_sets_[index]; } @@ -187,7 +187,7 @@ class HloModuleGroupMetadata { } // Returns the list of all companion sets in the HLO module group. - const std::vector>>& + const std::vector>>& companion_sets() const { return companion_sets_; } @@ -247,37 +247,36 @@ class HloModuleGroupMetadata { void DumpCollectedStats() const; // List of all companion instructions sets in the module. - std::vector>> - companion_sets_; + std::vector>> companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + absl::flat_hash_map companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). - tensorflow::gtl::FlatMap + absl::flat_hash_map tracked_instructions_; // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of // communicating instructions within the proper called computation(s). - tensorflow::gtl::FlatMap> + absl::flat_hash_map> tracked_instructions_comms_; // All channels in the module. std::vector channels_; // Map from channel ids to the index in channels_. - tensorflow::gtl::FlatMap channel_id_map_; + absl::flat_hash_map channel_id_map_; // Map from all-reduce ids to the all reduce instructions. - tensorflow::gtl::FlatMap> all_reduce_map_; + absl::flat_hash_map> all_reduce_map_; // The maximum channel id used in the module group. int64 max_channel_id_ = -1; // The modules that this metadata was built from. - const std::vector& modules_; + const std::vector modules_; - tensorflow::gtl::FlatMap> + absl::flat_hash_map> points_to_analyses_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7b12cb72b8df4610b964fb842da78e160d22d9f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + HloModuleGroup group(TestName(), std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + std::vector> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +// Tests that the order of companion instructions in the companion set doesn't +// change across runs. +TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) { + // A simple while loop template for core i sending to core i+1. + constexpr char text[] = R"( +HloModule module_%d + +while_cond { + ROOT p = pred[] constant(true) +} + +while_body { + param = s32[] parameter(0) + token.s = token[] after-all() + token.r = token[] after-all() + send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d + send-done = token[] send-done(send), channel_id=%d + recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d + ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d +} + +ENTRY entry { + while_init = s32[] constant(1) + ROOT while = s32[] while(while_init), condition=while_cond, body=while_body +} +)"; + + // Try creating the module and the metadata kTrialCount times and check the + // companion instructions remain in the same order. + const int64 kTrialCount = 5; + const int64 kDeviceCount = 10; + std::vector companion_order; + + for (int64 t = 0; t < kTrialCount; ++t) { + HloModuleGroup group(TestName()); + for (int64 i = 0; i < kDeviceCount; ++i) { + const int64 send_channel = i; + const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(absl::StrFormat(text, i, send_channel, send_channel, + recv_channel, recv_channel))); + group.push_back(std::move(module)); + } + ASSERT_EQ(group.modules().size(), kDeviceCount); + + TF_ASSERT_OK_AND_ASSIGN(auto metadata, + HloModuleGroupMetadata::Build(group.modules())); + ASSERT_EQ(metadata->companion_sets().size(), 1); + + std::vector module_ids; + for (HloInstruction* companion : *metadata->companion_sets()[0]) { + module_ids.push_back(metadata->GetModuleId(companion->GetModule())); + } + + if (t == 0) { + companion_order = module_ids; + } else { + EXPECT_TRUE(absl::c_equal(companion_order, module_ids)); + } + } +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index d83ee714905252e36f38438e81002a4d6ba7dafa..fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { std::vector predecessors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique predecessors list; if the predecessors is a companion // instruction, also add companion instructions; if the predecessors is a @@ -119,7 +119,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { std::vector successors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet unique; + absl::flat_hash_set unique; // Adds to the unique successors list; if the successor is a companion // instruction, also add companion instructions; if the successor is a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index 309c23045d1e0dd91e2f245d00c51d9bf9961bf5..f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -87,7 +87,7 @@ class HloModuleGroupUtil { // * visit_state: map from each instruction to its visit state. // * visit_function: function called when each instruction group. // * root: the root instruction of the traversal. - using VisitStates = tensorflow::gtl::FlatMap; + using VisitStates = absl::flat_hash_map; Status VisitTopologicalOrder(VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root); diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 62439434208b047d5ddb6110c61c3f8ed1768cb6..39f38b417ab0e8b54864176d8d1e0ad1a422eca6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" - #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 2d4e38589fe4693e73c46d6c82e51cb0a8388f85..4551a1c2e259b06818f913cb6a9e782436b7e594 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) { } StatusOr StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap({ + static auto* opcode_map = new absl::flat_hash_map({ #define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ {opcode_name, HloOpcode::enum_name}, HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index f1dc08bafa17a2dd68a7e922d4b84658bbf2589c..23d41d91d6969ddf9062507e926ae39c1e1315d4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { - // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' - // is live into the module. + // Entry parameter should always be defined before other instructions. const HloModule* module = b.defining_instruction()->parent()->parent(); if (b.defining_instruction()->parent() == module->entry_computation() && b.defining_instruction()->opcode() == HloOpcode::kParameter) { return false; } + if (a.defining_instruction()->parent() == module->entry_computation() && + a.defining_instruction()->opcode() == HloOpcode::kParameter) { + return true; + } + // Phi values require special handling. Because XLA does not have a phi // instruction, the definition instruction of the phis values are // placeholders: either the subcomputation parameter (body or condition) or @@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back(absl::StrFormat(" %s", predecessor->name())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b0361c3f02922bcaa14d52ad3b240701080f9b58..66313492eb2dd10ac9a6000639ddb8991b367c0f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering { // predecessors. An instruction is an element of its own predecessor set. // // Subclasses should fill this in to define the desired ordering. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> predecessors_; }; @@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering { // this map so more than one instruction may have the same position // value. This is not a problem because ExecutesBefore also verifies // instructions are in the same computation. - tensorflow::gtl::FlatMap order_position_; + absl::flat_hash_map order_position_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 00970bcda34209d33867099d0bcf3b2902d52ae8..b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } +TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { + // Entry parameter should always be defined before other instruction. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param), + dataflow->GetValueDefinedAt(constant))); + EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(param))); +} + TEST_F(HloOrderingTest, ValuesInWhileComputations) { // Tests the ordering of values (defined by dataflow analysis) in the body and // condition of a while instruction. HLO code: diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 11caa89c545e8fbfad96a9ab8e448a68a565e423..128113f7a53c5fa1463aa9e7a2891ff36ca46930 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -64,14 +64,11 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(absl::string_view str, const HloModuleConfig& config) - : lexer_(str), config_(config) {} + explicit HloParser(absl::string_view str) : lexer_(str) {} - // Runs the parser. Returns false if an error occurred. - bool Run(); - - // Returns the parsed HloModule. - std::unique_ptr ConsumeHloModule() { return std::move(module_); } + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns false if an error occurred. + Status Run(HloModule* module); // Returns the error information. string GetError() const { return StrJoin(error_, "\n"); } @@ -82,28 +79,37 @@ class HloParser { StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); - // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); - private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. std::pair* FindInstruction( const string& name, const optional& shape = nullopt); + // Parse a single instruction worth of text. + bool ParseSingleInstruction(HloModule* module); + // ParseXXX returns false if an error occurred. - bool ParseHloModule(); - bool ParseComputations(); + bool ParseHloModule(HloModule* module); + + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -168,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -234,6 +241,7 @@ class HloParser { bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); + bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -284,25 +292,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // A instruction's name is unique among its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation. So + // we need a stack to make sure a name is only visible within its scope, + std::vector scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. + class Scope { + public: + explicit Scope(std::vector* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map> computation_pool_; - HloLexer lexer_; - std::unique_ptr module_; std::vector> computations_; - const HloModuleConfig config_; std::vector error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function*(string, - const optional&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. + std::function*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -349,24 +379,50 @@ bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run() { +Status HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(); + if (lexer_.GetKind() == TokKind::kw_HloModule) { + // This means that the text contains a full HLO module. + if (!ParseHloModule(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a HloModule:\n%s", + GetError()); + } + return Status::OK(); + } + // This means that the text is a single HLO instruction. + if (!ParseSingleInstruction(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a single " + "HloInstruction:\n%s", + GetError()); + } + return Status::OK(); } std::pair* HloParser::FindInstruction( const string& name, const optional& shape) { - std::pair* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. - if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } // ::= 'HloModule' name computations -bool HloParser::ParseHloModule() { +bool HloParser::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { return TokenError("expects HloModule"); } @@ -385,22 +441,20 @@ bool HloParser::ParseHloModule() { return false; } - module_ = absl::make_unique(name, config_); - - if (!ParseComputations()) { + module->set_name(name); + if (!ParseComputations(module)) { return false; } if (is_scheduled.has_value() && *is_scheduled) { - TF_CHECK_OK( - module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); } return true; } // computations ::= (computation)+ -bool HloParser::ParseComputations() { +bool HloParser::ParseComputations(HloModule* module) { HloComputation* entry_computation = nullptr; do { if (!ParseComputation(&entry_computation)) { @@ -416,21 +470,20 @@ bool HloParser::ParseComputations() { if ((entry_computation != nullptr && computations_[i].get() != entry_computation) || (entry_computation == nullptr && i != computations_.size() - 1)) { - module_->AddEmbeddedComputation(std::move(computations_[i])); + module->AddEmbeddedComputation(std::move(computations_[i])); continue; } - auto computation = - module_->AddEntryComputation(std::move(computations_[i])); + auto computation = module->AddEntryComputation(std::move(computations_[i])); // The parameters and result layouts were set to default layout. Here we // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } @@ -447,7 +500,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -455,40 +507,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. - if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. - computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. - if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -497,43 +530,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if (!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -544,6 +596,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. std::unordered_map attrs; optional sharding; @@ -774,8 +839,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - auto loc = lexer_.GetLoc(); - optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; @@ -783,20 +846,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - switch (operands.size()) { - case 1: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), /*keys=*/operands[0])); - break; - case 2: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], /*values=*/operands[1])); - break; - default: - return Error(loc, StrCat("expects either 1 or 2 operands, but has ", - operands.size(), " operands")); - } + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + /*keys=*/operands[0], + /*values=*/absl::Span(operands).subspan(1))); break; } case HloOpcode::kTuple: { @@ -1274,21 +1327,65 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional opaque; optional window; optional dnums; optional feature_group_count; + optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( - shape, operands, *custom_call_target)); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } @@ -2151,7 +2248,20 @@ bool HloParser::ParseOperands(std::vector* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. + if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair* instruction = FindInstruction(name, shape); @@ -2304,9 +2414,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. + if (!ParseComputationName(&result)) { + return false; + } } static_cast*>(attr_out_ptr)->emplace(result); return true; @@ -2446,6 +2564,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2738,6 +2865,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// precision_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2745,23 +2889,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, @@ -3139,7 +3275,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3209,91 +3345,96 @@ StatusOr HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +bool HloParser::ParseSingleInstruction(HloModule* module) { + if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { + LOG(FATAL) << "Parser state is not clean. Please do not call any other " + "methods before calling ParseSingleInstruction."; + } + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, ¶meter_count]( - string name, - const optional& shape) -> std::pair* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, ¶meter_count]( + const string& name, + const Shape& shape) -> std::pair* { + string new_name = name.empty() ? StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; - // Prime the lexer. - lexer_.Lex(); - // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return false; + } + } else { + // This means that the instruction's left-hand side might exist, e.g. + // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return false; + } } - return Status::OK(); + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); + } + return true; } } // namespace StatusOr> ParseHloString( absl::string_view str, const HloModuleConfig& config) { - HloParser parser(str, config); - if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } - return parser.ConsumeHloModule(); + auto module = absl::make_unique(/*name=*/"_", config); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } StatusOr> ParseHloString(absl::string_view str) { - HloModuleConfig config; - return ParseHloString(str, config); + auto module = absl::make_unique(/*name=*/"_", HloModuleConfig()); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name) { - HloModuleConfig config; - HloParser parser(str, config); - auto builder = absl::make_unique(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr computation = builder->Build(); - auto module = absl::make_unique(string(name), config); - module->AddEntryComputation(std::move(computation)); - return std::move(module); +Status ParseHloString(absl::string_view str, HloModule* module) { + TF_RET_CHECK(module->computation_count() == 0); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module)); + return Status::OK(); } StatusOr ParseSharding(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseShardingOnly(); } StatusOr ParseWindow(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } StatusOr ParsePaddingConfig(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParsePaddingConfigOnly(); } diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 1882a184da8f09a9626daf7a2bbc531cb6ba6138..81eeb9f13bf7f06123c0b35e9f3352c197866a7a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -30,18 +30,18 @@ namespace xla { // For details about the syntax accepted by this parser, see // g3doc/hlo_parser.md. -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with the given config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. StatusOr> ParseHloString( absl::string_view str, const HloModuleConfig& config); -// Parses the text for a single HLO operation into an HLO module with a function -// that runs that operation (with the same parameters) as its entry computation. -StatusOr> ParseHloOpToModule( - absl::string_view str, absl::string_view name = "single_op"); +// Given a string in the HloModule::ToString() format, parses the string and +// builds the HloModule in place at the given module pointer. 'module' must +// point to an empty module (no computations). +Status ParseHloString(absl::string_view str, HloModule* module); -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with default config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index cca50fab5444d5e23c02952d56566b643a2192a4..ef2e74588cf1cfa3936564bb10346141833cd52b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -802,6 +802,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -966,6 +1003,21 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} } +)" +}, +// Sort (Key, Value, Value, Value) +{ +"SortManyValues", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values.0 = s32[1024,16]{0,1} parameter(1) + values.1 = u32[1024,16]{0,1} parameter(2) + values.2 = f32[1024,16]{0,1} parameter(3) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} +} + )" }, // Conditional @@ -1002,6 +1054,18 @@ ENTRY CustomCall { ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" } +)" +}, +// CustomCall with opaque value. +{ +"CustomCallWithOpaque", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque" +} + )" }, // Variables with non-default names @@ -1151,49 +1215,80 @@ ENTRY Sort { // clang-format on } -class HloParserTest : public ::testing::Test, - public ::testing::WithParamInterface { +// The test class for those tests defined above which round-trip through the +// parser and ToString is templatized on two bool parameters: +// +// short_form : used for the "short" test cases which use the ShortParsable +// output form. +// proto_round_trip : whether the module should also be round-tripped through +// HloProto form. This provides much better coverage for the proto +// serialization/deserialization. +// +// The proto_round_trip=true case also technically covers the Parser->ToString +// roundtrip as well, but separating out the Parser->ToString roundtrip as its +// own test provides better isolation and could conceivably catch weirdo bugs +// which are hidden by interaction between the textual and proto roundtripping. +template +class HloParameterizedParserTest + : public ::testing::Test, + public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(string_view s, string_view expected) { - EXPECT_TRUE(absl::StrContains(s, expected)) - << "'" << s << "' does not contain '" << expected << "'"; - } - // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString( - HloPrintOptions().set_print_large_constants(true))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(original)); + if (proto_round_trip) { + TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( + module->ToProto(), module->config())); + } + if (short_form) { + EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable())); + } else { + EXPECT_EQ( + original, + module->ToString(HloPrintOptions().set_print_large_constants(true))); + } } }; -class HloParserShortTest : public HloParserTest { - protected: - void ExpectEqualShort() { - const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, - result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); - } -}; +// These using shenanigans are required because the TEST_P macro doesn't like +// template instantiations which contain commas. +using HloParserTestLong = HloParameterizedParserTest; +using HloParserTestLongProto = HloParameterizedParserTest; +using HloParserTestShort = HloParameterizedParserTest; +using HloParserTestShortProto = HloParameterizedParserTest; -TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserTestLong, Run) { ExpectEqual(); } +TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } +TEST_P(HloParserTestShort, Run) { ExpectEqual(); } +TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, ::testing::ValuesIn(CreateTestCases()), TestDataToString); - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, ::testing::ValuesIn(CreateShortTestCases()), TestDataToString); +class HloParserTest : public ::testing::Test { + protected: + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } +}; + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = ParseHloString(original); @@ -1261,7 +1356,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1720,6 +1815,25 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, SameNameDiffComputations) { + const string original = R"(HloModule same_names: +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT result = f32[] add(p0, p1) +} + +ENTRY ReduceR3ToR2 { + p0 = f32[8,16,256]{2,1,0} parameter(0) + p1 = f32[] constant(0) + ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); + ASSERT_NE(module->entry_computation(), nullptr); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + TEST_F(HloParserTest, ParseSharding) { const string original = "{maximal device=42}"; TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); @@ -1773,27 +1887,142 @@ TEST(HloParserSingleOpTest, SingleOp) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " "f32[2,4]{1,0} %x)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Parameter(0), op::Parameter(1))); } -TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { +TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { + const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; + StatusOr> module = ParseHloString(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("expects '=' in instruction")); +} + +TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr> module = ParseHloOpToModule(text); + StatusOr> module = ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); - EXPECT_THAT( - module.status().ToString(), - ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("Operand had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, SingleOpNoNames) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, CanonicalOp) { + const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, CanonicalOpWithNested) { + const string text = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested) { + const string text = + R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= +{ + %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("does not exist: x")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[], f32[]) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const string text = R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1892,5 +2121,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); +} + +// custom call incompatible shape. + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index f1ad0f9b0148cb3d5f938e7f5d220d6cb82ea98d..fdaac34386c5135d6bbeb372d7a9199344836c8d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -25,15 +26,45 @@ limitations under the License. namespace xla { // Base class for HLO passes. These are used with the HloPassPipeline to -// organize a sequence of passes. +// organize a sequence of passes. An HLO pass should not extend this class +// directly; it should extend HloModulePass or HloModuleGroupPass. class HloPassInterface { public: virtual ~HloPassInterface() = default; virtual absl::string_view name() const = 0; - // Run the pass on the given HLO module. Return whether it modified the + // Run the pass on the given HLO module. Returns whether it modified the // module. virtual StatusOr Run(HloModule* module) = 0; + + // Run the pass on the given HLO module group. Returns whether it modified the + // module group. Ideally, the module group variant would be named "Run" as + // well, but C++ does not handle overloaded virtual methods well. + virtual StatusOr RunOnModuleGroup(HloModuleGroup* module_group) = 0; +}; + +// Base class for passes which are module-scoped. +class HloModulePass : public HloPassInterface { + public: + // Runs the pass on a module group by iterating through each module in the + // group. + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + TF_ASSIGN_OR_RETURN(bool module_changed, Run(module)); + changed |= module_changed; + } + return changed; + }; +}; + +// Base class for passes which are module-group scoped. These passes cannot run +// on an HLO module. +class HloModuleGroupPass : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override { + return InternalError("Module group pass cannot be run on a module"); + } }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6e4ed0de626688c0d836d6bc9c619245db8d61dd..5e004ce78ac1fd6da18ab2a54d23ef27e9586cf6 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,7 +17,8 @@ limitations under the License. #include -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -25,112 +26,131 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { -namespace { -using absl::StrAppend; -using absl::StrCat; - -void DumpModuleGraph(const HloModule& module, const string& message) { - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - XLA_VLOG_LINES(3, module.ToString()); +template +Status HloPassPipeline::RunInvariantCheckers( + HloT* hlo, absl::string_view after_pass_name) { + for (auto& invariant_checker : invariant_checkers_) { + VLOG(1) << " Invariant checker " << invariant_checker->name(); + StatusOr changed_status = RunHelper(invariant_checker.get(), hlo); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); + if (!changed_status.ok()) { + VLOG(2) << "Failed invariant check:"; + XLA_VLOG_LINES(2, hlo->ToString()); + return Status(changed_status.status().code(), + absl::StrCat(changed_status.status().error_message(), + "\n\nFailed after ", after_pass_name)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; + } + return Status::OK(); } -void DumpModuleProto(const HloModule& module, const string& dump_to, - const string& pipeline_name, const string& pass_name) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new tensorflow::gtl::FlatMap(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; +template +StatusOr HloPassPipeline::RunPassesInternal( + HloT* hlo, absl::Span passes) { + string last_pass_name = "pipeline-start"; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); + bool changed = false; + for (HloPassInterface* pass : passes) { + VLOG(1) << " HLO pass " << pass->name(); + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/pass->name()); + TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); + changed |= pass_changed; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name())); + last_pass_name = string(pass->name()); + } + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/"pipeline-end"); + return changed; +} - const string mod_name = SanitizeFileName( - absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), - pass_number, pipeline_name, pass_name)); +std::vector HloPassPipeline::GetEnabledPasses( + const DebugOptions& debug_options) { + auto repeated_field = debug_options.xla_disable_hlo_passes(); + absl::flat_hash_set disabled_pass_names(repeated_field.begin(), + repeated_field.end()); + if (!disabled_pass_names.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << absl::StrJoin(disabled_pass_names, ", "); + } - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), - dump_to, mod_name)); + std::vector enabled_passes; + for (auto& pass : passes_) { + if (disabled_pass_names.count(string(pass->name())) == 0) { + enabled_passes.push_back(pass.get()); + } + } + return enabled_passes; } -} // namespace -StatusOr HloPassPipeline::Run(HloModule* module) { - run_called_ = true; +void HloPassPipeline::MaybeDumpHlo(const HloModule& module, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + const string& proto_dump_path = + module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!proto_dump_path.empty()) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new absl::flat_hash_map(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string filename = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, name(), after_pass_name)); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( + MakeHloProto(module), proto_dump_path, filename)); + } - VLOG(1) << "Running HLO pass pipeline " << name(); + const string message = + StrCat("after ", after_pass_name, ", before ", before_pass_name); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + VLOG(3) << "HLO " << message << ":"; + XLA_VLOG_LINES(3, module.ToString()); +} - auto repeated_field = - module->config().debug_options().xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet disabled_passes(repeated_field.begin(), - repeated_field.end()); - if (!disabled_passes.empty()) { - VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << absl::StrJoin(disabled_passes, ", "); +void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (const HloModule* module : module_group.modules()) { + MaybeDumpHlo(*module, after_pass_name, before_pass_name); } +} - auto run_invariant_checkers = [this, - module](const string& message) -> Status { - for (auto& invariant_checker : invariant_checkers_) { - VLOG(1) << " Invariant checker " << invariant_checker->name(); - StatusOr changed_status = invariant_checker->Run(module); - VLOG(1) << " Invariant checker done " << invariant_checker->name(); - if (!changed_status.ok()) { - VLOG(2) << "Module failed invariant check:"; - XLA_VLOG_LINES(2, module->ToString()); - return Status(changed_status.status().code(), - StrCat(changed_status.status().error_message(), - "\n\nFailed ", message)); - } - TF_RET_CHECK(!changed_status.ValueOrDie()) - << "invariant checkers must not change the graph"; - } - return Status::OK(); - }; +StatusOr HloPassPipeline::Run(HloModule* module) { + run_called_ = true; - string prefix = StrCat(name(), ": pipeline start"); - bool changed = false; - string message; - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("before running pipeline: ", name()))); - const string xla_dump_per_pass_hlo_proto_to = - module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), - "pipeline_start"); - } + VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " + << name(); - for (auto& pass : passes_) { - if (disabled_passes.count(string(pass->name())) > 0) { - VLOG(1) << " Skipping HLO pass " << pass->name() - << ", disabled by --xla_disable_hlo_passes"; - continue; - } + return RunPassesInternal(module, + GetEnabledPasses(module->config().debug_options())); +} - VLOG(1) << " HLO pass " << pass->name(); +StatusOr HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) { + run_called_ = true; - // Emit label containing: "after foo-pass, before bar-pass". - message.clear(); - StrAppend(&message, prefix, ", before ", pass->name()); - DumpModuleGraph(*module, message); - - TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("after running pass: ", pass->name()))); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), - string(pass->name())); - } + VLOG(1) << "Running HLO pass pipeline on module group " + << module_group->name() << ": " << name(); - changed |= changed_this_pass; - prefix.clear(); - StrAppend(&prefix, name(), ": after ", pass->name()); + if (module_group->modules().empty()) { + VLOG(1) << "Module group is empty. Nothing to do."; + return false; } - DumpModuleGraph(*module, prefix + ", pipeline end"); - return changed; + + return RunPassesInternal( + module_group, + GetEnabledPasses(module_group->module(0).config().debug_options())); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 1d41a4dac1d8e2f392be0e4e856ead36a5b71d68..09e7033ea4ed88849d2f3665d04f74f3f388b3f5 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface { return *pass; } - // Run all passes on the given HLO module. StatusOr Run(HloModule* module) override; + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override; private: + // Returns the set of passes which are enabled. DebugOptions can selectively + // disable passes via --xla_disable_hlo_passes flag. + std::vector GetEnabledPasses( + const DebugOptions& debug_options); + + // Maybe dumps the given module or module group depending on flag values + // contained in DebugOptions of module config. + void MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, + absl::string_view before_pass_name); + + // Runs the invariant checker on the given HLO. HloT can be either HloModule + // or HloModuleGroup. + template + Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name); + + // Helper which runs the given pass on the given HLO. HloT can be either + // HloModule or HloModuleGroup. + template + StatusOr RunPassesInternal(HloT* hlo, + absl::Span passes); + + // Helpers which run the given passes on the given HLO construct. These + // helpers enable templating of the core of the pipeline logic by providing + // HloModule and HloModuleGroup specific methods with the same name. + static StatusOr RunHelper(HloPassInterface* pass, HloModule* module) { + return pass->Run(module); + } + static StatusOr RunHelper(HloPassInterface* pass, + HloModuleGroup* module_group) { + return pass->RunOnModuleGroup(module_group); + } + const string name_; std::vector> passes_; std::vector> invariant_checkers_; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee8cb12b231718e09f6ac0d05d7a6887f4c4d746 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloPassPipelineTest : public HloVerifiedTestBase { + protected: + StatusOr ParseModuleGroup( + absl::Span hlo_strings) { + HloModuleGroup group(TestName()); + for (const string& hlo_string : hlo_strings) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + group.push_back(std::move(module)); + } + return std::move(group); + } +}; + +// A module pass which renames instructions named 'foo' to 'bar'. +class FooToBarModulePass : public HloModulePass { + absl::string_view name() const override { return "foo2bar"; } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "foo") { + instruction->SetAndSanitizeName("bar"); + changed = true; + } + } + } + return changed; + } +}; + +// A module group pass which renames instructions named 'baz' to 'qux'. +class BazToQuxModuleGroupPass : public HloModuleGroupPass { + absl::string_view name() const override { return "baz2qux"; } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "baz") { + instruction->SetAndSanitizeName("qux"); + changed = true; + } + } + } + } + return changed; + } +}; + +// An invariant checker pass which returns an error if there exists an +// instruction named 'bar'. +class BarBlowerUpper : public HloModulePass { + absl::string_view name() const override { return "bar-blower-upper"; } + + StatusOr Run(HloModule* module) override { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "bar") { + return InternalError("Module has instruction named bar"); + } + } + } + return false; + } +}; + +TEST_F(HloPassPipelineTest, ModulePassChanged) { + // Test an HLO module pass which changes a module. + const string module_str = R"( +HloModule ModulePassChanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "foo"); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_EQ(root->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, ModulePassUnchanged) { + // Test an HLO module pass which does not change a module. + const string module_str = R"( +HloModule ModulePassUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT blahblah = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(HloPassPipelineTest, MixedPipeline) { + // Test a pipeline with both a module pass and a module group pass. + const string module_0_str = R"( +HloModule MixedPipeline.1 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT baz = f32[] multiply(a, b) +} +)"; + const string module_1_str = R"( +HloModule MixedPipeline.0 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group, + ParseModuleGroup({module_0_str, module_1_str})); + + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + pipeline.AddPass(); + + HloInstruction* root0 = + module_group.module(0).entry_computation()->root_instruction(); + HloInstruction* root1 = + module_group.module(1).entry_computation()->root_instruction(); + EXPECT_EQ(root0->name(), "baz"); + EXPECT_EQ(root1->name(), "foo"); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pipeline.RunOnModuleGroup(&module_group)); + EXPECT_TRUE(changed); + + EXPECT_EQ(root0->name(), "qux"); + EXPECT_EQ(root1->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, InvariantChecker) { + const string module_str = R"( +HloModule InvariantChecker + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + { + // Run a pipeline with just the invariant checker. It should not fail + // because there is no 'bar' instruction in the module. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); + } + + { + // Run a pipeline which renames 'foo' to 'bar' then an invariant checker + // which fails if there is an instruction named 'bar'. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after foo2bar")); + } + + { + // Run the invariant-checker only pipeline again. It should fail this time. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after pipeline-start")); + } +} + +TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) { + // Running a module group pass on a module should produce an error. + const string module_str = R"( +HloModule ModuleGroupPassOnModule + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Module group pass cannot be run on a module")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index b9c0b0c4ee1957fce48641230cef6391bcc9180e..026a0e8fba2a197a2825a7892f07353682140fa8 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include @@ -36,6 +37,17 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status()); + return std::move(module); +} + StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 3d9c375cd5d26f92cf8316f78789daf4fc08c927..1db82dd6fcaa5d7fe7d65894c1021105f0b26266 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,12 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Create an HLO state from serialized representation. In addition to +// creating the proto with HloModule::CreateFromProto(...) it also +// uses HloVerifier to ensure basic invariants are held. +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config); + // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. StatusOr> EntryComputationParameterShapes( diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2a07b6fcbc243d955e136ccdf097c8155a115845..2d5197be9e6f69f698729e06b7506a5bc6260bcd 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -24,7 +24,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalarF32(instruction->shape())) { + ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) { *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index b66a2aa4bd2b00a88cdbfa6b41c9123bb370aa87..5a5f01f8fd647c74217c80ce4a7633b8957e335f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -154,7 +154,7 @@ class HloReachabilityMap { // Dense assignment from HloInstruction* to number. These numbers index // into the bit_vectors_ vector and into the bits within a BitVector. - tensorflow::gtl::FlatMap indices_; + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index bd6dd79b679729adb6691ef809b19f06c6d5dd05..5ac43808ee2945eaa5003baad24d5d331419db83 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -75,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) { // cache before, and eventually calling the IsRematerializable() API. bool CanBeRematerialized( const HloInstruction* instruction, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { auto it = remat_able->find(instruction); if (it != remat_able->end()) { return it->second; @@ -268,7 +270,7 @@ class InstructionList { Item* first_; // Item for each instruction. - tensorflow::gtl::FlatMap item_map_; + absl::flat_hash_map item_map_; }; // Return the items which use the given LogicalBuffer. Sets @@ -503,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker( PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); - tensorflow::gtl::FlatMap + absl::flat_hash_map logical_buffer_to_buffer_id; for (auto* item = instruction_list_.first(); item != nullptr; @@ -854,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction, Item* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, int64 memory_limit_bytes, - tensorflow::gtl::FlatMap* remat_able) { + absl::flat_hash_map* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -980,10 +982,10 @@ StatusOr HloRematerialization::RematerializeComputation( // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the // blacklist. - tensorflow::gtl::FlatSet remat_move_instructions; + absl::flat_hash_set remat_move_instructions; // The map from instructions to their rematerializable status. - tensorflow::gtl::FlatMap remat_able; + absl::flat_hash_map remat_able; // The peak memory of the computation at any point in the instruction // sequence. @@ -1198,6 +1200,12 @@ StatusOr HloRematerialization::Run(HloModule* module) { << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + // Initialize pass object state. + computation_peak_memory_.clear(); + rematerialized_computations_.clear(); + instructions_rematerialized_ = 0; + net_instructions_added_ = 0; + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index e2aaf18b3e482bbf777c594c7f5a22832be2ac17..70d83c04f07ca7fd0139f586869e8fe688f958f4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -33,7 +35,7 @@ namespace xla { // CSE will undo the effects of this optimization and should not be run after // this pass. In general, this pass should be run very late, immediately before // code generation. -class HloRematerialization : public HloPassInterface { +class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function; @@ -115,14 +117,13 @@ class HloRematerialization : public HloPassInterface { // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization // occurs. - tensorflow::gtl::FlatMap - computation_peak_memory_; + absl::flat_hash_map computation_peak_memory_; std::unique_ptr points_to_analysis_; // Set of computations which have had rematerialization // applied. Rematerialization is only applied once per computation. - tensorflow::gtl::FlatSet rematerialized_computations_; + absl::flat_hash_set rematerialized_computations_; // Count of the total instructions rematerialized. int64 instructions_rematerialized_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 3fc5dbeb02a26134a7f255fa0b6ebda1dc41ce4d..9972eb20774550817143cb27dd94667364cf68ec 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -30,7 +32,7 @@ namespace xla { /* static */ StatusOr HloSchedule::CreateFromProto( const HloModule* module, const HloScheduleProto& proto) { - tensorflow::gtl::FlatMap id_to_computation; + absl::flat_hash_map id_to_computation; for (const HloComputation* computation : module->computations()) { id_to_computation[computation->unique_id()] = computation; } @@ -44,7 +46,7 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - tensorflow::gtl::FlatMap id_to_instruction; + absl::flat_hash_map id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -112,13 +114,13 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. - tensorflow::gtl::FlatMap id_to_instruction; + absl::flat_hash_map id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet ids_in_schedule; + absl::flat_hash_set ids_in_schedule; for (int id : sequences_.at(computation->unique_id()).ids()) { InsertOrDie(&ids_in_schedule, id); } @@ -126,15 +128,13 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's // operands that have not yet been scheduled. When this value reaches zero, // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; + absl::flat_hash_map unscheduled_operand_count; // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. @@ -211,15 +211,15 @@ Status HloSchedule::Update() { if (sequences_.size() > nonfusion_computations.size()) { // Schedule contains some computations which have been removed from the // HloModule. Remove them from the schedule as well. - tensorflow::gtl::FlatSet nonfusion_computations_ids; + absl::flat_hash_set nonfusion_computations_ids; for (const HloComputation* computation : nonfusion_computations) { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { if (nonfusion_computations_ids.count(it->first) == 0) { - it = sequences_.erase(it); + sequences_.erase(it++); } else { - it++; + ++it; } } } @@ -254,7 +254,7 @@ Status HloSchedule::Verify() const { // For each computation verify the set of instructions is the same and that // each dependency and control edge is honored. for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; + absl::flat_hash_map instruction_position; int pos = 0; for (const HloInstruction* instruction : sequence(computation).instructions()) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 270fe6039f0afd119c76086de9a0596e0560e93e..0a714101ee587aa847fa674bbde5586287c51f33 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -103,8 +104,7 @@ class HloSchedule { // Returns a map from HloComputation unique ID to instruction sequence. The // map contains all sequences in the schedule. - const tensorflow::gtl::FlatMap& sequences() - const { + const absl::flat_hash_map& sequences() const { return sequences_; } @@ -148,7 +148,7 @@ class HloSchedule { // A map from computation unique ID to instruction sequence. Unique IDs are // used rather than HloComputation pointers because HLO pointers are not // unique across HLO transformations because pointers may be recycled. - tensorflow::gtl::FlatMap sequences_; + absl::flat_hash_map sequences_; }; std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index de7e6b53d4d2aa88e2213248370b4da82bdeadeb..188f4acc7945f3ec98065eae5a87a41c39730432 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -369,10 +370,28 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } + + TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL) + << "Maximal sharding is expected to have single device assignment, but " + << proto.tile_assignment_devices().size() << " has provided."; + + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector devices(proto.tile_assignment_devices().begin(), diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index d1cf644f8273e632e2952cca0da749616e9b6233..fa34bddde1a47b520f7f96361d155e4017e44e60 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -22,7 +22,7 @@ namespace xla { // Unify subcomputations of a `HloModule`: if any computations are equal, choose // one arbitrarily to use and delete the others. -class HloSubcomputationUnification : public HloPassInterface { +class HloSubcomputationUnification : public HloModulePass { public: absl::string_view name() const override { return "subcomputation-unification"; diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 773fc7d22537ab81d945c197b713b00d322a7f24..59594ab2f0f70a206c73e998dbfa69c2c5c7ba43 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); + case HloOpcode::kDomain: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; @@ -166,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses( positions_.insert(positions_.end(), positions.begin(), positions.end()); // Gather the computation roots at which this value appears. - tensorflow::gtl::FlatSet root_positions; + absl::flat_hash_set root_positions; for (const HloPosition& position : positions_) { if (position.instruction == position.instruction->parent()->root_instruction()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 50f39cbcb55e29a2654ed8c745ea24ee2e0ab899..a1f668921d7d9a1db8e8e528fd02446838d6dd02 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -23,10 +24,18 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { +static Status CheckOperandCount(const HloInstruction* hlo, int expected) { + if (hlo->operand_count() != expected) { + return InternalError("Expected %d operands for %s instruction: %s", + expected, HloOpcodeString(hlo->opcode()), + hlo->ToString()); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -58,12 +67,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -74,6 +85,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -82,6 +94,7 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( @@ -92,6 +105,7 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -118,11 +132,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { } Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -156,6 +172,7 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -166,6 +183,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -192,10 +210,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - if (instruction->operand_count() != 2) { - return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString()); - } + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); @@ -244,29 +259,38 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - if (sort->operand_count() == 2 && - !ShapeUtil::SameDimensions(sort->operand(0)->shape(), - sort->operand(1)->shape())) { - return InternalError( - "Expected sort to have to have the same dimensions for the keys and " - "the values. Keys shape is: %s\n, Values shape is: %s", - StringifyShape(sort->operand(0)->shape()), - StringifyShape(sort->operand(1)->shape())); + if (sort->operand_count() < 1) { + return InternalError("Expected at least 1 operand for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); + } + for (int64 operand = 1; operand < sort->operand_count(); ++operand) { + if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(operand)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys " + "and the values. Keys shape is: %s\n, Values shape (operand index " + "%lld) is: %s", + StringifyShape(sort->operand(0)->shape()), operand, + StringifyShape(sort->operand(operand)->shape())); + } } return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); return CheckShape(constant, constant->literal().shape()); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); const int64 rank = ShapeUtil::Rank(iota->shape()); if (rank == 0) { @@ -281,6 +305,7 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), @@ -288,6 +313,12 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (reduce->operand_count() % 2 != 0) { + return InternalError( + "Expected an even number of operands for %s instruction: %s", + HloOpcodeString(reduce->opcode()), reduce->ToString()); + } + std::vector operand_shapes; for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); @@ -298,10 +329,12 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); @@ -313,14 +346,16 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_dimension < ShapeUtil::Rank(operand_shape); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)) + TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + (broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension))) << broadcast->ToString() << " operand shape " << operand_shape; } return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == @@ -329,12 +364,14 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } @@ -359,9 +396,30 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + TF_RET_CHECK(custom_call != nullptr); + if (custom_call->layout_constrained()) { + // If the layout is constrained, verify all the respective shapes have + // layouts and that the constrained operand shapes match the shapes of the + // operands. + TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape())); + TF_RET_CHECK(custom_call->operand_count() == + custom_call->operand_shapes_with_layout().size()); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + const Shape& operand_shape_with_layout = + custom_call->operand_shapes_with_layout()[i]; + TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), + operand_shape_with_layout)); + TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -369,6 +427,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), @@ -377,6 +436,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -406,6 +466,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -415,6 +476,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -425,6 +487,7 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -444,6 +507,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -458,12 +522,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { + TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -471,10 +537,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -482,6 +550,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -491,6 +560,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -501,6 +571,7 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -512,6 +583,7 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -548,6 +620,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -579,6 +652,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { + TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -587,6 +661,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -674,12 +749,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -687,6 +764,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -763,7 +841,177 @@ Status VerifyHloStructure(HloModule* module) { return Status::OK(); } -Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. +Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape())); + } + } + return Status::OK(); +} + +// Verifies that entry computation layout matches characteristics of +// entry computation. +Status CheckEntryComputationLayout(const HloModule& module) { + const HloComputation* computation = module.entry_computation(); + const auto& layout = module.entry_computation_layout(); + + // TODO(117498192): Change into a call to Compatible(...). + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + computation->root_instruction()->shape(), + layout.result_layout().shape())) { + return InternalError( + "Shape of the root instruction of entry computation (%s) should be " + "compatible to one specified in module's entry computation layout (%s)", + ShapeUtil::HumanString(computation->root_instruction()->shape()), + ShapeUtil::HumanString(layout.result_layout().shape())); + } + + if (computation->num_parameters() != layout.parameter_count()) { + return InternalError( + "Number of parameters in entry computation layout (%d) must be same " + "as number of parameters of entry computation computation (%d)", + layout.parameter_count(), computation->num_parameters()); + } + + for (int i = 0; i < computation->num_parameters(); ++i) { + if (!ShapeUtil::Compatible(computation->parameter_instruction(i)->shape(), + layout.parameter_shape(i))) { + return InternalError( + "Shape of the entry computation parameter %d is %s should be " + "compatible to the one specified in module's entry computation " + "layout %s", + i, + ShapeUtil::HumanString( + computation->parameter_instruction(i)->shape()), + ShapeUtil::HumanString(layout.parameter_shape(i))); + } + } + + return Status::OK(); +} + +// Checks if the given two instructions share the same channel id. +Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { + return InternalError( + "Expected instructions to have the same is-host-transfer property: " + "%s, " + "%s ", + instr1->ToString(), instr2->ToString()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. +Status VerifySendsAndRecvs(const HloModule& module) { + absl::flat_hash_map host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %d is used for multiple host send/recv instructions: " + "%s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); + } + } + + return Status::OK(); + }; + + // Send/Recv instruction must have a single user: the corresponding + // SendDone/RecvDone. with matching channel. + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } + } + return Status::OK(); +} + +// CHECKs various invariants of a fusion instruction. +Status CheckFusionInstruction(HloInstruction* fusion) { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { @@ -866,50 +1114,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } } + TF_RET_CHECK(fusion->called_computations() == + absl::Span( + {fusion->fused_instructions_computation()})) + << "Fusion HLO calls computations other than the " + "fused_instructions_computation: " + << fusion->ToString() << " fusion->fused_instructions_computation(): " + << fusion->fused_instructions_computation()->ToString() + << " fusion->called_computations(): " + << ComputationsToString(fusion->called_computations()); + + for (const auto& fused : fusion->fused_instructions()) { + TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation()) + << "Fused HLO was missing a parent: " << fused->ToString() + << " parent: " << fused->parent() + << " computation: " << fusion->parent(); + } + // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. return Status::OK(); } -Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - if (while_cond->num_parameters() != 1) { - return FailedPrecondition( - "While condition must have exactly 1 parameter; had %d : %s", - while_cond->num_parameters(), while_cond->ToString()); - } - if (while_body->num_parameters() != 1) { - return FailedPrecondition( - "While body must have exactly 1 parameter; had %d : %s", - while_body->num_parameters(), while_body->ToString()); - } - if (instruction->operand_count() != 1) { - return FailedPrecondition( - "While loop must have exactly one operand; had %d : %s", - instruction->operand_count(), instruction->ToString()); - } - return Status::OK(); -} - -Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { - if (instruction->true_computation()->num_parameters() != 1) { - return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %d", - instruction->true_computation()->name(), instruction->ToString(), - instruction->true_computation()->num_parameters()); - } - if (instruction->false_computation()->num_parameters() != 1) { - return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %d", - instruction->false_computation()->name(), instruction->ToString(), - instruction->false_computation()->num_parameters()); - } - return Status::OK(); -} - -Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { +// Checks that the non-scalar operand shapes are compatible to the output +// shape, i.e., that there are no implicit broadcasts of size-one dimensions. +Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); @@ -926,201 +1156,161 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } -namespace { +// Visitor which verifies various fields on the HLO instruction. This class does +// not check result shape as that is checked in the ShapeVerifier. +class InstructionVerifier : public DfsHloVisitorWithDefault { + public: + explicit InstructionVerifier(std::function + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} -// Returns true if the given Shape has a TOKEN shape as any subshape. -bool ShapeContainsToken(const Shape& shape) { - bool contains_token = false; - ShapeUtil::ForEachSubshape( - shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { - contains_token = true; - } - }); - return contains_token; -} + Status DefaultAction(HloInstruction*) override { return Status::OK(); } -// Verifies that all types entering and exiting the entry computation are -// legal. -Status VerifyEntryAndExitShapes(const HloModule& module) { - // Tokens cannot be passed as entry parameters. - // TODO(b/80000000): Remove this constraint. - for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { - HloInstruction* param = - module.entry_computation()->parameter_instruction(i); - if (ShapeContainsToken(param->shape())) { - return InternalError( - "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape())); - } + Status HandleFusion(HloInstruction* fusion) override { + return CheckFusionInstruction(fusion); } - return Status::OK(); -} -// Checks if the given two instructions share the same channel id. -Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%d), %s (%d)", - instr1->ToString(), instr1->channel_id(), instr2->ToString(), - instr2->channel_id()); + Status HandleBroadcast(HloInstruction* broadcast) override { + // If you see this failure then someone has confused the difference + // between the HLO broadcast op, and the UserComputation broadcast + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // or ComputationLowerer::Visit() + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(broadcast->operand(0)->shape())) + << "Broadcast HLO (" << broadcast->ToShortString() + << ") has invalid number of dimensions: " + << broadcast->dimensions().size() + << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + return Status::OK(); } - return Status::OK(); -} -// Checks if the given two instructions have the same is_host_transfer -// attribute value. Intsructions must be send/recv instructions or their -// 'done' variant. -Status CheckSameIsHostTransfer(const HloInstruction* instr1, - const HloInstruction* instr2) { - const HloSendRecvInstruction* send_recv1 = - DynCast(instr1); - const HloSendRecvInstruction* send_recv2 = - DynCast(instr2); - TF_RET_CHECK(send_recv1 != nullptr); - TF_RET_CHECK(send_recv2 != nullptr); - if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { - return InternalError( - "Expected instructions to have the same is-host-transfer property: " - "%s, " - "%s ", - instr1->ToString(), instr2->ToString()); + Status HandleWhile(HloInstruction* xla_while) override { + auto* while_cond = xla_while->while_condition(); + auto* while_body = xla_while->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); + } + if (xla_while->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %d : %s", + xla_while->operand_count(), xla_while->ToString()); + } + return Status::OK(); } - return Status::OK(); -} -// Checks various invariants of send and recv instructions. -Status VerifySendsAndRecvs(const HloModule& module) { - tensorflow::gtl::FlatMap host_channels; - // Host send/recv instructions must have their own unique channel. - auto check_unique_host_channel = [&](const HloInstruction* instruction) { - const HloSendRecvInstruction* sendrecv = - DynCast(instruction); - if (sendrecv->is_host_transfer()) { - auto it_inserted = - host_channels.insert({sendrecv->channel_id(), sendrecv}); - if (!it_inserted.second) { - return FailedPrecondition( - "Channel %d is used for multiple host send/recv instructions: " - "%s " - "and " - "%s", - sendrecv->channel_id(), sendrecv->ToString(), - it_inserted.first->second->ToString()); - } + Status HandleConditional(HloInstruction* conditional) override { + if (conditional->true_computation()->num_parameters() != 1) { + return FailedPrecondition( + "True computation %s of %s must have 1 parameter insted of %d", + conditional->true_computation()->name(), conditional->ToString(), + conditional->true_computation()->num_parameters()); + } + if (conditional->false_computation()->num_parameters() != 1) { + return FailedPrecondition( + "False computation %s of %s must have 1 parameter insted of %d", + conditional->false_computation()->name(), conditional->ToString(), + conditional->false_computation()->num_parameters()); } + return Status::OK(); + } + + Status HandleElementwiseUnary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + + Status HandleElementwiseBinary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + Status HandleGetTupleElement(HloInstruction* gte) override { + TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); return Status::OK(); - }; + } - // Send/Recv instruction must have a single user: the corresponding - // SendDone/RecvDone. with matching channel. - for (const HloComputation* computation : module.computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - switch (instruction->opcode()) { - case HloOpcode::kSend: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* send_done = instruction->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); - break; - } - case HloOpcode::kRecv: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); - break; + Status HandleTranspose(HloInstruction* transpose) override { + const Shape& shape = transpose->shape(); + const HloInstruction* operand = transpose->operand(0); + TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size()); + TF_RET_CHECK(shape.dimensions().size() == + transpose->operand(0)->shape().dimensions().size()); + TF_RET_CHECK(std::equal( + operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(transpose->dimensions(), shape.dimensions()).begin())) + << "shape: " << shape << ", operand->shape(): " << shape + << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ") + << "}"; + return Status::OK(); + } + + Status Preprocess(HloInstruction* instruction) override { + auto previous = instructions_by_name_.find(instruction->name()); + TF_RET_CHECK(previous == instructions_by_name_.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << instruction->parent()->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions_by_name_[instruction->name()] = instruction; + return Status::OK(); + } + + Status Postprocess(HloInstruction* instruction) override { + if (instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " + << ShapeUtil::HumanString(result_shape) << " To " + << ShapeUtil::HumanString(operand_shape); } - case HloOpcode::kSendDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); - break; - case HloOpcode::kRecvDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); - break; - default: - break; } } + + return Status::OK(); } - return Status::OK(); -} + + private: + absl::flat_hash_map instructions_by_name_; + // Determines whether an instruction can change layouts. + std::function + instruction_can_change_layout_func_; +}; } // namespace StatusOr HloVerifier::Run(HloModule* module) { + TF_RET_CHECK(!module->name().empty()); TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - tensorflow::gtl::FlatMap instructions; - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RET_CHECK(instruction->parent() == computation); - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK(instruction->called_computations() == - absl::Span( - {instruction->fused_instructions_computation()})) - << "Fusion HLO calls computations other than the " - "fused_instructions_computation: " - << instruction->ToString() - << " instruction->fused_instructions_computation(): " - << instruction->fused_instructions_computation()->ToString() - << " instruction->called_computations(): " - << ComputationsToString(instruction->called_computations()); - - for (const auto& fused : instruction->fused_instructions()) { - TF_RET_CHECK(fused->parent() == - instruction->fused_instructions_computation()) - << "Fused HLO was missing a parent: " << fused->ToString() - << " parent: " << fused->parent() - << " computation: " << computation; - } - } else if (instruction->opcode() == HloOpcode::kBroadcast) { - // If you see this failure then someone has confused the difference - // between the HLO broadcast op, and the UserComputation broadcast - // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I - // or ComputationLowerer::Visit() - TF_RET_CHECK(instruction->dimensions().size() == - ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO (" << instruction->ToShortString() - << ") has invalid number of dimensions: " - << instruction->dimensions().size() - << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); - } else if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); - } else if (instruction->opcode() != - HloOpcode::kRng /* Rng operands are always scalar. */ - && instruction->IsElementwise()) { - TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); - } - - auto previous = instructions.find(instruction->name()); - TF_RET_CHECK(previous == instructions.end()) - << "HLO has name that is not unique within module:\n" - << instruction->ToString() - << " in computation: " << computation->name() - << "\nPrevious HLO with same name:\n" - << previous->second->ToString() - << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction; - } - std::unique_ptr shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); + + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); + TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } + TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module)); TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); // If the module has a schedule, it must be valid. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 42e3027bf14a827bd0a791510c2d9c107d989ab9..cb49cb95ba8949b84f57d985bdb07a3177edbc5a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -151,15 +151,21 @@ class ShapeVerifier : public DfsHloVisitor { // HLO pass that verifies invariants of HLO instructions for each computation in // the module. -class HloVerifier : public HloPassInterface { +class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function()>; - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}) : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { return absl::make_unique(layout_sensitive, allow_mixed_precision); - }) {} + }), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } // Uses custom shape verification. explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) @@ -172,22 +178,15 @@ class HloVerifier : public HloPassInterface { StatusOr Run(HloModule* module) override; private: - // CHECKs various invariants of a fusion instruction. - Status CheckFusionInstruction(HloInstruction* fusion) const; - - Status CheckWhileInstruction(HloInstruction* instruction); - - Status CheckConditionalInstruction(HloInstruction* instruction); - - // Checks that the non-scalar operand shapes are compatible to the output - // shape, i.e., that there are no implicit broadcasts of size-one dimensions. - Status CheckElementwiseInstruction(HloInstruction* instruction); - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This is a factory function because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; + + // Determines whether an instruction can change layouts. + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8f0423bb1c72ceb209437116a898d027f4d2c657..afe01e5487c3225815e01343d86e9fe894c2cde8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { HasSubstr("non-positive base area dilation factor")); } +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index 85bb4a8b2450a48d461f1d84e0609a38a6818d9c..9c48b7db613b049536c76237b4cfebbbc47448f3 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -25,7 +25,7 @@ namespace xla { // Pass which replaces all implicit broadcasts with their equivalent sequence of // explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloPassInterface { +class ImplicitBroadcastRemover : public HloModulePass { public: ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 06f0e1ed25e71659a61e6de8a84e52cf70064eae..1ebb3319779c00fd4afe90606bf336e16349429d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -23,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( absl::InlinedVector stack; enum DfsState { kDiscovered, kVisited }; - gtl::FlatMap dfs_state_map; + absl::flat_hash_map dfs_state_map; stack.push_back(root); InsertOrDie(&dfs_state_map, root, kDiscovered); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index df9cbab915cc037cec682238886fb524eaeb2c90..e5aa67fd850de647652d66017223e19fb92cc068 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -360,13 +360,13 @@ class IndexedArrayAnalysis { std::vector> owned_tensors_; std::vector owned_literals_; - tensorflow::gtl::FlatMap cache_; + absl::flat_hash_map cache_; }; // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to // unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloPassInterface { +class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: absl::string_view name() const override; StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 3fdc2cee9aad0fe70f66920f757ee5c52bba711f..69a4c160ee5c4539272c3085338dc6de1b9347ff 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,11 +22,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -188,13 +189,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_duplicate) { + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map, bool>* + result_cache) { if (consumer == producer) { return true; } if (!consumer->IsFusible()) { return false; } + auto cache_it = result_cache->find(std::make_pair(producer, consumer)); + if (cache_it != result_cache->end()) { + return cache_it->second; + } + bool result = true; for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter @@ -202,20 +210,23 @@ bool InstructionFusion::CanFuseOnAllPaths( if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_duplicate.count(consumer_operand) > 0 || - !ShouldFuse(consumer, i)) { - return false; + if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + result = false; + break; } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { - return false; + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse, + result_cache)) { + result = false; + break; } } - return true; + result_cache->emplace(std::make_pair(producer, consumer), result); + return result; } InstructionFusion::HloInstructionSet @@ -231,6 +242,8 @@ InstructionFusion::ComputeGloballyUnfusible( // fusing operations that require duplication later depending on // is_expensive_(). HloInstructionSet do_not_duplicate; + absl::flat_hash_map, bool> + can_fuse_on_all_paths_result_cache; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { if (do_not_duplicate.count(producer) > 0) { @@ -286,7 +299,8 @@ InstructionFusion::ComputeGloballyUnfusible( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. if (producer->IsFusible() && - CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate, + &can_fuse_on_all_paths_result_cache)) { continue; } do_not_duplicate.insert(producer); @@ -417,7 +431,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { private: std::vector post_order_; - tensorflow::gtl::FlatMap post_order_index_; + absl::flat_hash_map post_order_index_; }; } // namespace diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c1fde8ecfc04792c6c17ebd83190486ef720175a..f14c6675208c72112aea0179c238b58709d625b5 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,3 +1,4 @@ +#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -24,39 +26,12 @@ limitations under the License. namespace xla { -// A queue interface that allows implementations to choose fusion candidates in -// custom order. -class FusionQueue { - public: - FusionQueue() = default; - virtual ~FusionQueue() = default; - - // Dequeues the next fusion candidates: a consumer and the list of producers - // as operand indices. - virtual std::pair> - DequeueNextInstructionAndOperandsToFuseInOrder() = 0; - - // A callback passed to the queue implementation right before the producer is - // fused into the consumer. - virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} - - // A callback passed to the queue implementation right after the fusion is - // created. Note that original_producer could have been destroyed. - virtual void OnFusingInstruction(HloInstruction* fusion, - HloInstruction* original_producer, - HloInstruction* original_consumer) {} - - // A callback passed to the queue implementation to notify the removal of an - // instruction. - virtual void RemoveInstruction(HloInstruction* instruction) = 0; -}; - // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in // code generation. Derived classes define ShouldFuse method to select which // instructions to fuse. -class InstructionFusion : public HloPassInterface { +class InstructionFusion : public HloModulePass { public: explicit InstructionFusion( std::function is_expensive, @@ -151,8 +126,15 @@ class InstructionFusion : public HloPassInterface { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. - bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_fuse); + // + // A map from to a bool is required as the result cache + // to store and query the results of calls to this function, in order to avoid + // repeated computations. + bool CanFuseOnAllPaths( + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map, bool>* + result_cache); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 146c9052f10cca8b199a480491d9a672d8bebdff..1484e14df10d94841c5a2e849761779f5800392d 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -45,8 +45,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index bb69cb9c47ff2c7de8d13832c4b8e6216c62da73..7c79eb7d791bc9a0743605d3171ff69c6ef41d58 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 082bf8bffed484244139e79f4d3fe30ca091d8ac..2cf5fc94aca98335f19a9156f037e1adce5a6ed0 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints( return Status::OK(); } +namespace { + +bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + return custom_call != nullptr && custom_call->layout_constrained(); +} + +} // namespace + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { - Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted // instruction. @@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - shape_with_layout = ¶meter_layout.shape(); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + parameter_layout.shape(), instruction)); } } - } - if (shape_with_layout != nullptr) { + } else if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(*shape_with_layout, instruction)); - } - - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kRecv) { + constraints->SetInstructionLayout(custom_call->shape(), custom_call)); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } else if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); @@ -498,6 +511,22 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (instruction->IsCrossModuleAllReduce()) { + CHECK(get_channel_constraints(instruction)) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 all_reduce_id = instruction->all_reduce_id().value(); + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(all_reduce_id)) { + continue; + } + // TODO(b/68493863): Change to use SetOperandLayout(). + const Shape& buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -605,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); - } else if (instruction->opcode() == HloOpcode::kCustomCall) { - if (!CustomCallRequiresMajorFirstLayout(instruction)) { - continue; - } - // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support major-first layouts for all inputs and outputs. - Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( - instruction->shape().element_type(), - AsInt64Slice(instruction->shape().dimensions())); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction)); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const Shape& operand_shape = instruction->operand(i)->shape(); - // Opaque operands don't get a layout constraint. - if (ShapeUtil::IsOpaque(operand_shape)) { - continue; - } - - Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i)); - } } } // Finally set the result layout to match ComputationLayout, if there is one. @@ -660,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call, return Status::OK(); } -// Custom calls have fixed input and output layouts. -Status CheckCustomCallLayout(HloInstruction* custom_call) { - for (const HloInstruction* operand : custom_call->operands()) { - TF_RET_CHECK( - ShapeUtil::IsOpaque(operand->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); +// Operands of layout-constrained custom calls must match the expected +// constrained layouts. +Status CheckCustomCallLayout(HloInstruction* instruction) { + if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); + } } - TF_RET_CHECK( - ShapeUtil::IsOpaque(custom_call->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -776,21 +782,27 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << " instruction: " << instruction->ToString(); if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. + // Copy tuple elements which have differing layouts. std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); ++i) { + const Shape& target_shape = + ShapeUtil::GetSubshape(shape_with_layout, {i}); + const Shape& instr_shape = + ShapeUtil::GetSubshape(instruction->shape(), {i}); HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - SetupCopiedInstruction(*instruction, gte, {i}); - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); + HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); + + if (ShapeUtil::Equal(target_shape, instr_shape)) { + // Shapes and layouts are equal, no need to copy. + element_copies.push_back(gte); + } else { + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each element. + TF_ASSIGN_OR_RETURN(HloInstruction * element_copy, + CreateCopyWithNewLayout(target_shape, gte)); + element_copies.push_back(element_copy); + } } // Gather element copies into a tuple with a new Tuple instruction. HloInstruction* tuple_copy = instruction->parent()->AddInstruction( @@ -910,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - if (CustomCallRequiresMajorFirstLayout(instruction)) { - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); - } + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -958,10 +968,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), - channel_layout_constraints_(channel_constraints) { + channel_layout_constraints_(channel_constraints), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { if (channel_layout_constraints_ != nullptr) { // Save a copy of the input ChannelLayoutConstraints so that we can reset it // if we have to undo previous operations (ClearPreviousPassSideEffects()). @@ -982,7 +997,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. // @@ -1060,7 +1075,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(user)) { + !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); } @@ -1512,19 +1527,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - - // Copy the root instruction's result if its layout does not match the result - // layout constraint. - if (constraints.ResultLayout() != nullptr && - !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), - computation->root_instruction())); - computation->set_root_instruction(new_root); - } - return Status::OK(); } @@ -1540,11 +1542,11 @@ Status LayoutAssignment::CalculateComputationLayout( Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidentally use the existing layout. + // by the LayoutAssignment pass, except for those on parameters, the + // computation result, and a couple special cases. The former two are + // specified in computation_layout. Clearing the layouts here avoids hiding + // potential bugs in the layout assignment pass that may accidentally use the + // existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction @@ -1553,7 +1555,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } - if (instruction->opcode() != HloOpcode::kInfeed) { + // Some instructions carry mandatory layouts in their shape. + if (instruction->opcode() != HloOpcode::kInfeed && + !IsLayoutConstrainedCustomCall(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -1654,6 +1658,18 @@ Status LayoutAssignment::RunOnComputation( TF_RETURN_IF_ERROR( ConstrainChannelLayouts(computation, channel_constraints)); } + + // Copy the root instruction's result if its layout does not match the result + // layout constraint. + if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } return Status::OK(); } @@ -1709,6 +1725,30 @@ Status LayoutAssignment::ConstrainChannelLayouts( ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); *send_shape = shape; } + } else if (instruction->IsCrossModuleAllReduce()) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel(instruction->all_reduce_id().value(), + instruction->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the channel wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. + HloInstruction* operand = instruction->mutable_operand(0); + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + *instruction->mutable_shape() = shape; + } } } return Status::OK(); @@ -1752,6 +1792,18 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(Init()); + // Verify computation layout is sane. + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry_computation_layout_->parameter_count() == + entry->num_parameters()); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + TF_RET_CHECK( + ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i), + entry->parameter_instruction(i)->shape())); + } + TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(), + entry->root_instruction()->shape())); + // We do two passes. The first one we pass a nullptr ComputationLayout to // the RunOnComputation() calls (for non entry computations), and we register // the ComputationLayout which are naturally flowing in DFS fashion to the @@ -1803,7 +1855,8 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } -bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( +/* static */ +bool LayoutAssignment::InstructionCanChangeLayout( const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kAbs: @@ -1822,7 +1875,6 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: - case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -1856,6 +1908,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kRemainder: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kScatter: case HloOpcode::kSelect: case HloOpcode::kSelectAndScatter: case HloOpcode::kShiftLeft: @@ -1869,7 +1922,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kTanh: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: - return true; + return false; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1879,6 +1932,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: + case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kFusion: @@ -1893,14 +1947,13 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kReduce: case HloOpcode::kReshape: case HloOpcode::kRng: - case HloOpcode::kScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kAfterAll: case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: - return false; + return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index cf545031d3c7c66770ea4a2392a2df3b8c24cd38..cb56f4cd19ded036ef521a579eb7d6ea7f3b6268 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,8 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -228,8 +228,8 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; - mutable tensorflow::gtl::FlatMap> + mutable absl::flat_hash_map> buffer_sets_cache_; HloComputation* computation_; @@ -281,11 +281,16 @@ class ChannelLayoutConstraints { // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. -class LayoutAssignment : public HloPassInterface { +class LayoutAssignment : public HloModulePass { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. // + // instruction_can_change_layout_func is a function object that determines + // whether an instruction can change layouts. An instruction not being able to + // change layout means that it requires operands with the same rank as the + // output to have the same layout as the output. + // // channel_constraints is both an input and output. Any sends or recvs that // are present in channel_constraints will be laid out as constrained. Any // unconstrained sends or recvs will be laid out as locally optimal and their @@ -295,6 +300,8 @@ class LayoutAssignment : public HloPassInterface { // within any module passed to `Run`. explicit LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function + instruction_can_change_layout_func = InstructionCanChangeLayout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} absl::string_view name() const override { return "layout-assignment"; } @@ -303,10 +310,10 @@ class LayoutAssignment : public HloPassInterface { // (any layouts were changed). StatusOr Run(HloModule* module) override; - // Returns true if the instruction requires that operands with the same rank - // as the output have to have the same layout as the output. - virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( - const HloInstruction* instruction); + // Determines whether an instruction can change layouts. An instruction not + // being able to change layout means that it requires operands with the same + // rank as the output to have the same layout as the output. + static bool InstructionCanChangeLayout(const HloInstruction* instruction); protected: // These methods, invoked by PropagateConstraints, propagate a layout @@ -326,19 +333,6 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); - // By default LayoutAssignment ensures that inputs and outputs of CustomCalls - // have the "major-first" layout (i.e. {n, n-1, ..., 0}). - // - // If this function returns true, LayoutAssignment does not set a layout for - // the given CustomCall. It's up to the backend to set one in - // AddBackendConstraints, if necessary. - // - // Precondition: instruction->opcode() == HloOpcode::kCustomCall. - virtual bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* /*instruction*/) { - return true; - } - // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { @@ -504,7 +498,7 @@ class LayoutAssignment : public HloPassInterface { // Every copy added to the module by the layout assignment pass is registered // here. - tensorflow::gtl::FlatSet added_copies_; + absl::flat_hash_set added_copies_; // The pointer to the channel layout constraints passed in with the // constructor. If not nullptr, this is an input/output argument. @@ -521,8 +515,10 @@ class LayoutAssignment : public HloPassInterface { // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. - tensorflow::gtl::FlatSet - unconstrained_layout_instructions_; + absl::flat_hash_set unconstrained_layout_instructions_; + + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 752a61476dd7892a2b7f531c4057015f48fc4758..ff6fdb5e4aab68ad30088112e3716da5ca6c516a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( - entry_computation_layout, /*channel_constraints=*/channel_constraints); + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } @@ -64,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list> minor_to_majors) { + int i = 0; + for (const absl::Span minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -860,6 +882,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { + // Pin non matching layouts to parameter and root. + const char* module_str = R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + ar.0 = f32[2,2] cross-replica-sum(gte), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=0} + const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] cross-replica-sum(const), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=1} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); +} + TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { const char* module_str = R"( HloModule CopySliceOperandToAvoidImplicitLayoutChange @@ -998,5 +1064,233 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { op::ShapeWithLayout(shape_copy)))); } +TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { + // The first infeed uses layout {0,1}, while the second uses layout {1,0}. + // The mismatch forces a copy of the tuple. The tuple contains a token, so + // layout assignment will fail if it tries to copy the whole tuple. + const char* module_str = R"( + HloModule TupleCopyOnLayoutMismatch + + condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] { + tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.1 = s32[] get-tuple-element(tup.1), index=0 + five = s32[] constant(5) + ROOT lt = pred[] less-than(counter.1, five) + } + + body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { + tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.2 = s32[] get-tuple-element(tup.2), index=0 + tok.2 = token[] get-tuple-element(tup.2), index=1 + + ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2) + next_tok = token[] get-tuple-element(ifeed.2), index=1 + next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0 + + one = s32[] constant(1) + next_counter = s32[] add(counter.2, one) + ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf) + } + + ENTRY main () -> f32[512,1024]{0,1} { + start_tok = token[] after-all() + + ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok) + itok = token[] get-tuple-element(ifeed.3), index=1 + ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0 + + zero = s32[] constant(0) + itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf) + + loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2 + ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2 + } + )"; + + ParseAndVerifyModule(module_str); + ComputationLayout computation_layout( + module().entry_computation()->ComputeProgramShape()); + + // Sanity check to verify that there's a layout mismatch. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + + AssignLayouts(&module(), &computation_layout); + + // Make sure that layout assignment did not magically eliminate the mismatch, + // in which case the test didn't prove anything. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); +} + +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedZeroOperands + +ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall())); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleOperand + +ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); + + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::CustomCall(op::Tuple()))); + + const HloInstruction* custom_call = + module->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleResult + +ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) { + %p0 = f32[4,4] parameter(0) + ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}} +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); + AssignLayouts(module.get(), &computation_layout); + + ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(), + {{1, 0}, {1, 0}}); + + const HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call"); + ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 540bbb7c7a74f65ab70f4c6704d6600db2adbb60..6223a34b1258961944a3ac64cd10876d1272c94e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index e5370eca56f2e3a891523ba2b72961d66ec809aa..643ecd0fbaa546c551097b29e74ccd49418e1466 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" -#include +#include #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - tensorflow::gtl::FlatSet - buffers; + std::set buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. if (!assignment.HasAllocation(*buffer) || diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6..2b46b3c3964b15548dbacc8b0ada0047a0fa85b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace llvm_ir { @@ -77,14 +76,14 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - tensorflow::gtl::FlatMap + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - tensorflow::gtl::FlatMap + absl::flat_hash_map noalias_metadata_; }; diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index eaa09591b72ee5202e0a9d1225d92eca92904adc..ec52a24d782a44fda961feab3230886072e755c7 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() { // so reserve 10% more than the number of instructions to avoid frequent // resizes. logical_buffers_.clear(); - logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10); + logical_buffers_.reserve((module_->instruction_count() * 11) / 10); // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc similarity index 76% rename from tensorflow/compiler/xla/service/inliner.cc rename to tensorflow/compiler/xla/service/map_inliner.cc index 5fd779ebf9b59e34a0844cc3a898bb72ce6044ee..2200ef054a6993fb884751643ab1fb5ab83efe05 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/map_inliner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include @@ -32,10 +32,10 @@ limitations under the License. namespace xla { -// InlinerVisitor traverses the HLO computation and inlines maps. -class InlinerVisitor : public DfsHloVisitorWithDefault { +// MapInlinerVisitor traverses the HLO computation and inlines maps. +class MapInlinerVisitor : public DfsHloVisitorWithDefault { public: - explicit InlinerVisitor(HloComputation* computation) + explicit MapInlinerVisitor(HloComputation* computation) : computation_(computation) {} // Default visitor action is to do nothing and return OK. @@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault { StatusOr Run(HloComputation* computation); private: - // Current HloComputation instance the InlinerVisitor is traversing. + // Current HloComputation instance the MapInlinerVisitor is traversing. HloComputation* computation_; // Whether algebraic simplification has occurred. bool changed_ = false; }; -StatusOr InlinerVisitor::Run(HloComputation* computation) { +StatusOr MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); return changed_; } -Status InlinerVisitor::HandleMap(HloInstruction* map) { +Status MapInlinerVisitor::HandleMap(HloInstruction* map) { HloComputation* function = map->to_apply(); HloInstruction& root = *function->root_instruction(); - // TODO(b/29249531): Add DCE pass to remove unused HloComputations. // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. return Status::OK(); } VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " << root.ToShortString(); - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. - if (root.opcode() != HloOpcode::kConstant) { - std::vector params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // // The constant is in an embedded computation and needs to be recreated // as part of the computation that the broadcast is inserted into. HloInstruction* constant = computation_->AddInstruction(root.Clone()); @@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast(map->shape(), constant, {})); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; return Status::OK(); @@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { return Status::OK(); } -StatusOr Inliner::Run(HloModule* module) { - InlinerVisitor visitor(/*computation=*/nullptr); +StatusOr MapInliner::Run(HloModule* module) { + MapInlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; for (HloComputation* computation : module->computations()) { TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h similarity index 59% rename from tensorflow/compiler/xla/service/inliner.h rename to tensorflow/compiler/xla/service/map_inliner.h index efa8ed3abcc6cd7cd8d31ec2170eae8752988c09..b67911811846e2250068921ef252b7df596d4016 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/map_inliner.h @@ -13,27 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -// A pass which performs inlining. Which can result, for example, in functions -// that were previously being mapped by Map instead directly applied to the -// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloPassInterface { +// A pass which performs map inlining. This replaces kMap instructions with +// their equivalent sequence of array operations. For example: +// map({X, Y}, add) -> add(X, Y)). +class MapInliner : public HloModulePass { public: - ~Inliner() override = default; - absl::string_view name() const override { return "inline"; } + ~MapInliner() override = default; + absl::string_view name() const override { return "map-inline"; } - // Run inlining on the given computation. Returns whether the computation was - // changed. + // Run map inlining on the given computation. Returns whether the computation + // was changed. StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc similarity index 78% rename from tensorflow/compiler/xla/service/inliner_test.cc rename to tensorflow/compiler/xla/service/map_inliner_test.cc index 7e967f035c1054e22d10790188a5a232ca8e751a..84059dd0f71ee8fc0a25703cbab2268d7dc149a8 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include #include @@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` -TEST_F(InlinerTest, MapMax) { +TEST_F(MapInlinerTest, MapMax) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto max_builder = HloComputation::Builder(TestName()); @@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) { } // Test that `constant` function is changed to `broadcast`. -TEST_F(InlinerTest, MapConstant) { +TEST_F(MapInlinerTest, MapConstant) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto const2_builder = HloComputation::Builder(TestName()); @@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_F(InlinerTest, MapSubtractOppositeOrder) { +TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); // Note that the parameter ordinals are in the opposite order to their @@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); @@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } +TEST_F(MapInlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index b9ec31c4977be0c31dfff01a0c495902191d7d5b..2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +50,7 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { all_fusion_candidates_.push_back(instruction); std::vector candidates; - tensorflow::gtl::FlatSet candidates_set; + absl::flat_hash_set candidates_set; VLOG(10) << "Looking at instruction: " << instruction->name(); for (auto operand : instruction->operands()) { // Filter out the non-interesting instructions -- they @@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { // Update the fusible list for fusion. Variable new_fusibles keeps // track of the new or changed entries. std::vector> new_fusibles; - tensorflow::gtl::FlatSet in_list; + absl::flat_hash_set in_list; auto it = fusion_node.fusibles.begin(); while (it != fusion_node.fusibles.end()) { HloInstruction* instr = it->first; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d2c52651c4f37708906e31b7839d0c9f6f04760e..9508ab2ed1d38ec40983d8892ec8875b848fb21b 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -44,7 +45,7 @@ namespace xla { // Note that the reachability map is updated based on the original computation. // This works because the reachability is monotonically increasing with // instruction fusion. -class MultiOutputFusion : public HloPassInterface { +class MultiOutputFusion : public HloModulePass { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} @@ -126,7 +127,7 @@ class MultiOutputFusion : public HloPassInterface { std::vector candidates_; // A map that maps an instruction to the index_. - tensorflow::gtl::FlatMap candidates_index_; + absl::flat_hash_map candidates_index_; // The reachability map of current computation. std::unique_ptr reachability_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index bd8fb17a235ea6eeb0e1809e8cb9ad83145fd8d6..ac2f79674feceff436c0e9c65338967f498e4473 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) { } /*static*/ string NameUniquer::GetSanitizedName(const string& name) { + if (name.empty()) { + return ""; + } string result = name; - CHECK(!result.empty()) << "name should not be empty"; char c = static_cast(result[0]); if (!isalpha(c) && c != '_') { result[0] = '_'; diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 6dd89c240f81c9f0ccac66e50c7f244bfd5429f1..8909d0f4fea801e43ab06a75e8933d24a74146bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -69,7 +69,7 @@ class NameUniquer { int64 next_ = 0; // Set of all the identifiers which has been used. - tensorflow::gtl::FlatSet used_; + absl::flat_hash_set used_; }; // The string to use to separate the prefix of the name from the uniquing @@ -78,7 +78,7 @@ class NameUniquer { // Map from name prefix to the generator data structure which tracks used // identifiers and generates new ones. - tensorflow::gtl::FlatMap generated_names_; + absl::flat_hash_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 4869db79e719fa10d61ad6c6ed41ff70a344f733..380cde0e6a858c7800445be94bb08dc22f3e776a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -17,8 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #include "absl/strings/string_view.h" +#include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -116,15 +120,82 @@ namespace xla { // .WithOperand(1, Op(&c)) // .WithOperand(2, Op(&d)) // + +struct MatchOption { + // If true, actually capture matched item into the user pointer. + bool capture; +}; + template -bool Match(Value* value, const Pattern& pattern) { - return pattern.Match(value); +bool Match(Value* value, const Pattern& pattern, + MatchOption option = {/*.capture=*/true}) { + if (option.capture) { + auto new_option = option; + new_option.capture = false; + if (!pattern.Match(value, new_option)) { + return false; + } + } + return pattern.Match(value, option); } namespace match { namespace detail { +template +class AllOfPattern { + public: + explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + bool Match(Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return std::get(patterns_).Match(item, option) && + MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return true; + } + + std::tuple patterns_; +}; + +} // namespace detail + +// Returns a pattern that represents the conjunction of all input patterns. All +// patterns need to match in order to have the AllOf pattern match. +// +// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is +// not AllOf. We might want to flatten the AllOf type structure if the +// C++ compile error message gets annoying. +template +detail::AllOfPattern::type, Patterns...> AllOf( + const Patterns&... patterns) { + return detail::AllOfPattern::type, + Patterns...>(patterns...); +} + +namespace detail { + template class LayoutPattern; @@ -132,57 +203,61 @@ class LayoutPattern; // nullptr. class LayoutPatternBaseImpl { public: - bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout != nullptr; + } }; // A LayoutPattern implementation that matches only if the layout equals a // Layout proto. -template class LayoutPatternEqualImpl { public: - explicit constexpr LayoutPatternEqualImpl(const Previous& previous, - const ::xla::Layout* layout) - : previous_(previous), layout_(layout) {} + explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout) + : layout_(layout) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return LayoutUtil::Equal(*layout_, *layout); } private: - Previous previous_; const ::xla::Layout* layout_; }; // A LayoutPattern implementation that matches only if the layout has a given // format. -template class LayoutPatternFormatImpl { public: - explicit constexpr LayoutPatternFormatImpl(const Previous& previous, - Format format) - : previous_(previous), format_(format) {} + explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && layout->format() == format_; + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout->format() == format_; } private: - Previous previous_; Format format_; }; // A pattern that matches Layouts. template class LayoutPattern { + private: + template + LayoutPattern> + AppendImpl(NewImpl new_impl) const { + return LayoutPattern>( + AllOf(impl_, std::move(new_impl)), matched_layout_); + } + public: explicit constexpr LayoutPattern(const Impl& impl, LayoutType** matched_layout) : impl_(impl), matched_layout_(matched_layout) {} // Returns true and captures the layout iff it matches the pattern. - bool Match(const ::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(const ::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -191,9 +266,9 @@ class LayoutPattern { } // Returns true and captures the layout iff it matches the pattern. - bool Match(::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -203,24 +278,21 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr LayoutPattern> EqualTo( - const ::xla::Layout* layout) const { - return LayoutPattern>( - LayoutPatternEqualImpl(impl_, layout), matched_layout_); + constexpr auto EqualTo(const ::xla::Layout* layout) const + -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. - constexpr LayoutPattern> - WithDenseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + constexpr auto WithDenseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + return AppendImpl(LayoutPatternFormatImpl(DENSE)); } // Modifies the pattern to match only if the layout has a sparse format. - constexpr LayoutPattern> - WithSparseFormat() const { - return LayoutPattern>( - LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + constexpr auto WithSparseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) { + return AppendImpl(LayoutPatternFormatImpl(SPARSE)); } private: @@ -228,8 +300,72 @@ class LayoutPattern { LayoutType** matched_layout_; }; +template +class AnyOfPattern { + public: + explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + bool Match(Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant()); + } + + private: + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + auto new_option = option; + new_option.capture = false; + // Try to match the sub-pattern without capturing behavior. + if (std::get(patterns_).Match(item, new_option)) { + // Capture the branch. + if (option.capture) { + // TODO(timshen): Currently the behavior can be exponential. Optimize it + // with memoization or recording the matched sub-pattern index, if it + // takes too long to run. + // + // Specifically, the "memoization" approach is to create an empty + // container with the key (pattern, instruction), and value as whether + // matched or not. + // + // Alternatively, we may run the pattern matching with captures off, but + // instead record a "trace" somewhere, indicating how exactly the + // pattern matches the input. For example, the trace information for + // AnyOf will be a runtime number indicate which sub-pattern is matched. + // Then we run another pass to do captures only with the help of the + // trace. + bool ret = std::get(patterns_).Match(item, option); + DCHECK(ret); + } + return true; + } + return MatchImpl(item, option, std::integral_constant()); + } + + template + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant) const { + return false; + } + + std::tuple patterns_; +}; + } // namespace detail +// Returns a pattern that represents the logical disjunction of the input +// patterns. The returned pattern matches from left to right, and stops on the +// first match. +template +detail::AnyOfPattern::type, Patterns...> AnyOf( + const Patterns&... patterns) { + return detail::AnyOfPattern::type, + Patterns...>(patterns...); +} + // Creates a layout pattern that will capture the matched layout in the // argument. inline constexpr detail::LayoutPattern class ShapePatternEqualImpl { public: - explicit constexpr ShapePatternEqualImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Equal(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape is compatible to // a Shape proto. -template class ShapePatternCompatibleImpl { public: - explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Compatible(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape has a given // element type. -template class ShapePatternElementTypeImpl { public: - explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, - PrimitiveType element_type) - : previous_(previous), element_type_(element_type) {} + explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type) + : element_type_(element_type) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && shape->element_type() == element_type_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return shape->element_type() == element_type_; } private: - Previous previous_; PrimitiveType element_type_; }; // A ShapePattern implementation that matches only if the shape is scalar. -template class ShapePatternIsScalarImpl { public: - explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsScalarImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsScalar(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is an array -template class ShapePatternIsArrayImpl { public: - explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsArrayImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsArray(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is a tuple. -template class ShapePatternIsTupleImpl { public: - explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsTupleImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsTuple(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape has a given // rank. -template class ShapePatternRankImpl { public: - explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) - : previous_(previous), rank_(rank) {} + explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Rank(*shape) == rank_; } private: - Previous previous_; int64 rank_; }; // A ShapePattern implementation that matches only if the shape has a layout // that matches a given pattern. -template +template class ShapePatternLayoutImpl { public: explicit constexpr ShapePatternLayoutImpl( - const Previous& previous, const LayoutPattern& layout) - : previous_(previous), layout_(layout) {} + : layout_(layout) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(&shape->layout()); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout(), option); } - bool Match(Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout()); + bool Match(Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout(), option); } private: - Previous previous_; LayoutPattern layout_; }; // A ShapePattern implementation that matches only if the shape has a subshape // that matches a given pattern. -template +template class ShapePatternSubshapeImpl { public: explicit ShapePatternSubshapeImpl( - const Previous& previous, ShapeIndexView index, + ShapeIndexView index, const ShapePattern& subshape) - : previous_(previous), index_(index), subshape_(subshape) {} + : index_(index), subshape_(subshape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); } - bool Match(::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + bool Match(::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), + option); } private: - Previous previous_; ShapeIndexView index_; ShapePattern subshape_; }; @@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl { // A pattern that matches Shapes. template class ShapePattern { + private: + template + ShapePattern> AppendImpl( + NewImpl new_impl) const { + return ShapePattern>( + AllOf(impl_, std::move(new_impl)), matched_shape_); + } + public: explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) : impl_(impl), matched_shape_(matched_shape) {} // Returns true and captures the shape iff it matches the pattern. - bool Match(const ::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -447,9 +564,9 @@ class ShapePattern { } // Returns true and captures the shape iff it matches the pattern. - bool Match(::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -459,108 +576,90 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr ShapePattern> EqualTo( - const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternEqualImpl(impl_, shape), matched_shape_); + constexpr auto EqualTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. - constexpr ShapePattern> - CompatibleTo(const ::xla::Shape* shape) const { - return ShapePattern>( - ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + constexpr auto CompatibleTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. - constexpr ShapePattern> - WithElementType(PrimitiveType element_type) const { - return ShapePattern>( - ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + constexpr auto WithElementType(PrimitiveType element_type) const + -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr ShapePattern> IsScalar() - const { - return ShapePattern>( - ShapePatternIsScalarImpl(impl_), matched_shape_); + constexpr auto IsScalar() const + -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. - constexpr ShapePattern> IsArray() - const { - return ShapePattern>( - ShapePatternIsArrayImpl(impl_), matched_shape_); + constexpr auto IsArray() const + -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr ShapePattern> IsTuple() - const { - return ShapePattern>( - ShapePatternIsTupleImpl(impl_), matched_shape_); + constexpr auto IsTuple() const + -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + return AppendImpl(ShapePatternIsTupleImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr ShapePattern> WithRank( - int64 rank) const { - return ShapePattern>( - ShapePatternRankImpl(impl_, rank), matched_shape_); + constexpr auto WithRank(int64 rank) const + -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. template - constexpr ShapePattern> - WithLayout(const LayoutPattern& layout) const { - return ShapePattern>( - ShapePatternLayoutImpl(impl_, layout), - matched_shape_); - } - - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - WithLayoutEqualTo(const ::xla::Layout* layout) const { + auto WithLayout(const LayoutPattern& layout) const + -> decltype(this->AppendImpl( + ShapePatternLayoutImpl(layout))) { + return AppendImpl(ShapePatternLayoutImpl(layout)); + } + + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const + -> decltype(this->WithLayout(Layout().EqualTo(layout))) { return WithLayout(Layout().EqualTo(layout)); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsDenseArray() const { + constexpr auto IsDenseArray() const + -> decltype(this->WithLayout(Layout().WithDenseFormat())) { return WithLayout(Layout().WithDenseFormat()); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl>> - IsSparseArray() const { + constexpr auto IsSparseArray() const + -> decltype(this->WithLayout(Layout().WithSparseFormat())) { return WithLayout(Layout().WithSparseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. template + auto WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) + const -> decltype(this->AppendImpl( + ShapePatternSubshapeImpl(index, + subshape))) { + return AppendImpl( + ShapePatternSubshapeImpl(index, subshape)); + } + ShapePattern> - WithSubshape(ShapeIndexView index, - const ShapePattern& subshape) const { - return ShapePattern< - ShapeType, ShapePatternSubshapeImpl>( - ShapePatternSubshapeImpl(impl_, index, - subshape), - matched_shape_); - } - - ShapePattern>> + AllOfPattern>>> WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, ShapePattern( @@ -568,9 +667,12 @@ class ShapePattern { .EqualTo(shape)); } - ShapePattern>> + ShapePattern>>> WithSubshapeCompatibleTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, @@ -611,159 +713,169 @@ class HloInstructionPattern; // instruction is not nullptr. class HloInstructionPatternBaseImpl { public: - bool Match(const ::xla::HloInstruction* inst) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return inst != nullptr; } }; // An HloInstructionPattern implementation that matches only if the instruction // has a given name. -template class HloInstructionPatternNameImpl { public: - explicit HloInstructionPatternNameImpl(const Previous& previous, - absl::string_view name) - : previous_(previous), name_(name) {} + explicit HloInstructionPatternNameImpl(absl::string_view name) + : name_(name) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->name() == name_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->name() == name_; } private: - Previous previous_; absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. -template class HloInstructionPatternOpcodeImpl { public: - explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, - HloOpcode opcode, + explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode, bool invert) - : previous_(previous), opcode_(opcode), invert_(invert) {} + : opcode_(opcode), invert_(invert) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return (invert_ ^ (inst->opcode() == opcode_)); } private: - Previous previous_; HloOpcode opcode_; bool invert_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a shape that matches a given pattern. -template +template class HloInstructionPatternShapeImpl { public: explicit constexpr HloInstructionPatternShapeImpl( - const Previous& previous, const ShapePattern& shape) - : previous_(previous), shape_(shape) {} + const ShapePattern& shape) + : shape_(shape) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(&inst->shape()); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(&inst->shape(), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(inst->mutable_shape(), option); } private: - Previous previous_; ShapePattern shape_; }; // An HloInstructionPattern implementation that matches only if the instruction // has an operand that matches a given pattern. -template +template class HloInstructionPatternOperandImpl { public: explicit constexpr HloInstructionPatternOperandImpl( - const Previous& previous, int64 operand_index, + int64 operand_index, const HloInstructionPattern& operand) - : previous_(previous), operand_index_(operand_index), operand_(operand) {} + : operand_index_(operand_index), operand_(operand) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_)); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_), option); } private: - Previous previous_; int64 operand_index_; HloInstructionPattern operand_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. -template class HloInstructionPatternFusionKindImpl { public: explicit constexpr HloInstructionPatternFusionKindImpl( - const Previous& previous, ::xla::HloInstruction::FusionKind kind) - : previous_(previous), kind_(kind) {} + ::xla::HloInstruction::FusionKind kind) + : kind_(kind) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } private: - Previous previous_; ::xla::HloInstruction::FusionKind kind_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a kGetTupleElement with a particular tuple index. -template class HloInstructionPatternTupleIndexImpl { public: - explicit constexpr HloInstructionPatternTupleIndexImpl( - const Previous& previous, int64 tuple_index) - : previous_(previous), tuple_index_(tuple_index) {} + explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index) + : tuple_index_(tuple_index) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } private: - Previous previous_; int64 tuple_index_; }; +template +class HloPredicatePatternImpl { + public: + explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + + bool Match(const ItemType* item, MatchOption option) const { + return pred_(item); + } + + bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + + private: + Predicate pred_; +}; + +struct PatternFriend; + // A pattern that matches HloInstructions. template class HloInstructionPattern { + private: + template + HloInstructionPattern> + AppendImpl(NewImpl new_impl) const { + return HloInstructionPattern< + HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( + AllOf(impl_, std::move(new_impl)), matched_inst_); + } + public: explicit constexpr HloInstructionPattern(const Impl& impl, HloInstructionType** matched_inst) : impl_(impl), matched_inst_(matched_inst) {} // Returns true and captures the instruction iff it matches the pattern. - bool Match(const ::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -772,9 +884,9 @@ class HloInstructionPattern { } // Returns true and captures the instruction iff it matches the pattern. - bool Match(::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -783,102 +895,87 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - HloInstructionPattern> - WithName(absl::string_view name) const { - return HloInstructionPattern>( - HloInstructionPatternNameImpl(impl_, name), matched_inst_); + auto WithName(absl::string_view name) const + -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - constexpr HloInstructionPattern> - WithOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, false), - matched_inst_); + auto WithOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + false))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. - constexpr HloInstructionPattern> - WithoutOpcode(HloOpcode opcode) const { - return HloInstructionPattern>( - HloInstructionPatternOpcodeImpl(impl_, opcode, true), - matched_inst_); + auto WithoutOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + true))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr HloInstructionPattern> - IsConstant() const { + constexpr auto IsConstant() const + -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr HloInstructionPattern> - IsNonConstant() const { + constexpr auto IsNonConstant() const + -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl> - WithShape(const ShapePattern& shape) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl>( - HloInstructionPatternShapeImpl(impl_, - shape), - matched_inst_); + constexpr auto WithShape(const ShapePattern& shape) + const -> decltype(this->AppendImpl( + HloInstructionPatternShapeImpl(shape))) { + return AppendImpl( + HloInstructionPatternShapeImpl(shape)); } // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. template - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl> - WithOperand( + constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern& operand) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl>( - HloInstructionPatternOperandImpl( - impl_, operand_index, operand), - matched_inst_); + const HloInstructionPattern& operand) const + -> decltype(this->AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand))) { + return AppendImpl( + HloInstructionPatternOperandImpl( + operand_index, operand)); } // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. - constexpr HloInstructionPattern> - WithFusionKind(HloInstruction::FusionKind kind) const { - return HloInstructionPattern>( - HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const + -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. - constexpr HloInstructionPattern> - WithTupleIndex(int64 tuple_index) const { - return HloInstructionPattern>( - HloInstructionPatternTupleIndexImpl(impl_, tuple_index), - matched_inst_); + constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( + this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } private: + template + constexpr auto WithPredicate(Predicate pred) const -> decltype( + this->AppendImpl(HloPredicatePatternImpl( + std::move(pred)))) { + return AppendImpl( + HloPredicatePatternImpl(std::move(pred))); + } + + friend struct PatternFriend; + Impl impl_; HloInstructionType** matched_inst_; }; @@ -1005,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose) .WithOperand(0, std::forward(lhs)) \ .WithOperand(1, std::forward(rhs)); \ } -XLA_BINOP_PATTERN(Add) + +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ + return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ + } \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs))) { \ + return AnyOf(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs)); \ + } +XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(Eq) +XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) XLA_BINOP_PATTERN(Gt) XLA_BINOP_PATTERN(Le) XLA_BINOP_PATTERN(Lt) -XLA_BINOP_PATTERN(Maximum) -XLA_BINOP_PATTERN(Minimum) -XLA_BINOP_PATTERN(Multiply) -XLA_BINOP_PATTERN(Ne) +XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) +XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) +XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) +XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) -XLA_BINOP_PATTERN(And) -XLA_BINOP_PATTERN(Or) +XLA_COMMUTATIVE_BINOP_PATTERN(And) +XLA_COMMUTATIVE_BINOP_PATTERN(Or) XLA_BINOP_PATTERN(ShiftLeft) XLA_BINOP_PATTERN(ShiftRightArithmetic) XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_COMMUTATIVE_BINOP_PATTERN #undef XLA_BINOP_PATTERN // Helpers for ternary instructions. @@ -1070,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN +namespace detail { +struct PatternFriend { + template + static auto ConstantScalar(T constant) -> decltype( + Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate( + std::declval>())) { + std::function pred = + [constant](const HloInstruction* instr) { + const auto& literal = Cast(instr)->literal(); + auto status_or_const = LiteralUtil::CreateR0(constant).Convert( + literal.shape().element_type()); + return status_or_const.ok() && + literal == status_or_const.ConsumeValueOrDie(); + }; + + return Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate(std::move(pred)); + } +}; +} // namespace detail + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1107,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } +template +inline auto ConstantScalar(T constant) + -> decltype(detail::PatternFriend::ConstantScalar(constant)) { + return detail::PatternFriend::ConstantScalar(constant); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index a530581c34bf1d699eae3c53203c197f7943cc53..3ab7b7fd7168d7ddd1470fdb03a04ba7b171fddb 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } +TEST(PatternMatcherTest, AnyOf) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(1)))); + EXPECT_TRUE( + Match(root, match::AnyOf(match::ConstantScalar(1), + match::ConstantScalar(0)))); + EXPECT_FALSE( + Match(root, match::AnyOf(match::ConstantScalar(0), + match::ConstantScalar(2)))); +} + +TEST(PatternMatcherTest, ConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(root, match::ConstantScalar(41))); + EXPECT_FALSE(Match(root, match::ConstantScalar(0))); +} + +TEST(PatternMatcherTest, NoMatchConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_FALSE(Match(root, match::ConstantScalar(42))); +} + +TEST(PatternMatcherTest, MultiplyAnyOrder) { + using match::ConstantScalar; + using match::MultiplyAnyOrder; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + const HloInstruction* instr; + + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); +} + +TEST(PatternMatcherTest, AnyOfShortCircuit) { + using match::AnyOf; + using match::Multiply; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Multiply(&mul, Op(), Op()), Op(&any)))); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(nullptr, any); + } + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf(Op(&any), Multiply(&mul, Op(), Op())))); + EXPECT_NE(nullptr, any); + EXPECT_EQ(nullptr, mul); + } +} + +TEST(PatternMatcherTest, AllOf) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); + auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); + ASSERT_TRUE(Match(root, scalar_pattern)); + ASSERT_TRUE(Match(root, f16_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); + EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE( + Match(root, AllOf(Broadcast(Op()), scalar_pattern))); +} + +TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_FALSE( + Match(root, AllOf(Constant(&constant), Broadcast(Op())))); + EXPECT_EQ(nullptr, constant); + ASSERT_TRUE(Match(root, Constant(&constant))); + EXPECT_NE(nullptr, constant); +} + +TEST(PatternMatcherTest, TestNoCapture) { + using match::Constant; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false})); + EXPECT_EQ(nullptr, constant); +} + +TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + u = f16[] parameter(0) + v = f16[] parameter(1) + ROOT add = f16[] add(u, v) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* addend0 = nullptr; + const HloInstruction* addend1 = nullptr; + const HloInstruction* addend2 = nullptr; + auto add2_pattern = Add(Op(&addend0), Op(&addend1)); + auto add3_pattern = AnyOf( + AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0)); + + ASSERT_TRUE(Match(root, add3_pattern)); + EXPECT_NE(nullptr, addend0); + EXPECT_NE(nullptr, addend1); + EXPECT_EQ(nullptr, addend2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 178a78ede09c34e71566fdee69793fdb1cda6245..c522e7ae23b734090f85d241bf365fccc37f0adb 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware // thread. Because we parallelize a single computation across threads, it - // doesn't make sense to expose these as separate devices, so fix the number - // of devices to one. - device_count = 1; + // doesn't make sense to expose these as separate devices, so by default we + // fix the number of devices to one. However we do let the user override + // this behavior to help run tests on the host that run models in parallel + // across multiple devices. + device_count = legacy_flags::GetDebugOptionsFromFlags() + .xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 256b231e3af43a2ee85c97a5efab1f022d4de4b1..0b4e82e8d606cf2cacfab42d07c2201939d5e10b 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -22,14 +22,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { // HLO pass which inserts reduce-precision instructions into the HLO graph, for // purposes of experimenting with the effects of reduced-precision storage of // intermediate values. -class ReducePrecisionInsertion : public HloPassInterface { +class ReducePrecisionInsertion : public HloModulePass { using InstructionFilterFunction = std::function; public: diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1e86a0823a56a9e52421a5c8bd49e0adb98a2c70..a3db439e34000ef3fcf4b190cb372947e285a64e 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -24,7 +24,7 @@ namespace xla { // This now only moves them outputward across elementwise ops all whose operands // are equivalent Reshapes or Transposes, but in future could potentially move // them inputward also. -class ReshapeMover : public HloPassInterface { +class ReshapeMover : public HloModulePass { public: absl::string_view name() const override { return "reshape-mover"; } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 2f4b2667c405bb23b1c648892c86d337400c14a5..de7aee262e61195b37099fc661a95508d0539e18 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -155,6 +155,53 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } +static StatusOr CheckIndexValidity( + HloComputation* computation, HloInstruction* index, + absl::Span operand_dims, absl::Span window_sizes, + HloModule* module) { + DCHECK_NE(nullptr, module); + DCHECK_EQ(operand_dims.size(), window_sizes.size()); + + // Valid range for the index: [0, operand_dims - window_sizes] + + // Check if the index has any negative values. + TF_ASSIGN_OR_RETURN( + HloInstruction * zero_index, + BroadcastZeros(computation, index->shape().element_type(), + AsInt64Slice(index->shape().dimensions()))); + TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, + MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + + // Check if the index is OOB w.r.t. the operand dimensions and window sizes. + std::vector max_valid_index(operand_dims.size()); + for (int i = 0; i < operand_dims.size(); ++i) { + max_valid_index[i] = operand_dims[i] - window_sizes[i]; + } + TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + MakeR1ConstantHlo(computation, index->shape().element_type(), + max_valid_index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_index_check, + MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + + // Combine the results of the two checks above. + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index, + MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check)); + + // Reduce the index validity check vector into a scalar predicate. + auto reduction_init = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index_reduced, + MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module)); + + // Return a broadcasted value of the scalar predicate to the same size as the + // window. + return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); +} + // Body of the while loop that performs the scatter operation using other HLOs. static StatusOr> ScatterLoopBody( HloInstruction* scatter, HloInstruction* induction_var, @@ -222,7 +269,16 @@ static StatusOr> ScatterLoopBody( InsertDegenerateDims(update_slice_for_scatter, AsInt64Slice(dim_numbers.inserted_window_dims()))); - // Extact the slice to update from `operand` tensor. + // Note that the following transformation assumes that both DynamicSlice and + // DynamicUpdateSlice follow the same semantics for OOB indices. For example, + // if there are negative indices and DynamicSlice uses "clamping" semantics, + // then the extracted data will be "shifted". Since DynamicUpdateSlice also + // follows the same "clamping" semantics, writing the update will also be + // "shifted" by exactly the same amount. So, this transformation is correct as + // long as the semantics of handling OOB indices remain the same in + // DynamicSlice and DynamicUpdateSlice. + + // Extract the slice to update from `operand` tensor. const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); TF_ASSIGN_OR_RETURN( HloInstruction * operand_slice_to_update, @@ -237,10 +293,24 @@ static StatusOr> ScatterLoopBody( MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, scatter->to_apply())); + TF_ASSIGN_OR_RETURN( + HloInstruction * is_index_valid, + CheckIndexValidity( + operand->parent(), scatter_slice_start, + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()), + scatter->GetModule())); + + // Select the updated operand only if the index is valid. If not, select the + // original value. + TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply, + MakeSelectHlo(is_index_valid, updated_operand_slice, + operand_slice_to_update)); + // Write the updated value of the slice into `operand` tensor. - TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, - MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, - scatter_slice_start)); + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start)); return StatusOr>{ {updated_operand, scatter_indices, updates}}; diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 14f062c89cfd4657097c1a933621a3e945f89c53..559a85dccfef27816e7dbf746fd71c44bbf46f60 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -20,7 +20,7 @@ limitations under the License. namespace xla { -class ScatterExpander : public HloPassInterface { +class ScatterExpander : public HloModulePass { public: absl::string_view name() const override { return "scatter_expander"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 922ebdf0e3f0e79674c5a632c873627845a606ec..084df17951b565cbe066d54cb74699bba1ef4bd3 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -207,7 +207,7 @@ Status Service::ValidateResultShape(const Shape& client_shape, StatusOr>> Service::ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors) { + absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -590,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments) { + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -812,7 +812,7 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(module_proto, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, @@ -1160,7 +1160,7 @@ StatusOr> Service::Replicas( return replicas; } -Status Service::MaybeDumpHloModule(const HloModule& module) const { +Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const { const string xla_dump_unoptimized_hlo_proto_to = module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); if (xla_dump_unoptimized_hlo_proto_to.empty()) { @@ -1168,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const { } HloProto proto = MakeHloProto(module); return protobuf_util::DumpProtoToDirectory( - proto, xla_dump_unoptimized_hlo_proto_to, module.name()); + proto, xla_dump_unoptimized_hlo_proto_to, + StrCat(module.name(), ".unoptimized")); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 44c5248b150cff57546d3287869787f37c8975ba..8cf1a7b9f01fbb3572c6849c8b18e14174ced89f 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments); + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -208,7 +208,7 @@ class Service : public ServiceInterface { StatusOr>> ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors); + absl::Span stream_executors) const; // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -271,7 +271,9 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; - Status MaybeDumpHloModule(const HloModule& module) const; + // Dumps the (unoptimized) module given if the corresponding DebugOptions + // field has been set. + Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const; // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 74bdf2a2e3982bc9be29bae037e385fede578ae5..aa49f98bcff1c6759ad049339e0247f45b3c2dad 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers( // Check that dimension numbers are unique. auto dims_unique = [](absl::Span contracting_dims, absl::Span batch_dims) -> bool { - tensorflow::gtl::FlatSet dim_set; + absl::flat_hash_set dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; @@ -1029,17 +1029,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kSort: { if (operand_shapes.size() == 1) { return *operand_shapes[0]; - } else if (operand_shapes.size() == 2) { - if (!ShapeUtil::SameDimensions(*operand_shapes[0], - *operand_shapes[1])) { - return InvalidArgument( - "Sort keys and values dimensions must match. " - "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]), - ShapeUtil::HumanString(*operand_shapes[1])); + } else { + for (int64 operand = 1; operand < operand_shapes.size(); ++operand) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[operand])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape (operand index %lld) is: %s", + ShapeUtil::HumanString(*operand_shapes[0]), operand, + ShapeUtil::HumanString(*operand_shapes[operand])); + } + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); } - return ShapeUtil::MakeTupleShape( - {*operand_shapes[0], *operand_shapes[1]}); + return ShapeUtil::MakeTupleShape(operand_shape_values); } return InvalidArgument("Unexpected number of operands for sort"); } @@ -1665,10 +1670,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( "Expected LHS feature dimension (value %d) to match RHS " - "input feature dimension * feature_group_count (value %d); " + "input feature dimension * feature_group_count (value %d * %d = %d); " "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features * feature_group_count, + input_features, kernel_input_features, feature_group_count, + kernel_input_features * feature_group_count, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } @@ -2379,7 +2385,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Transpose dimensions not a permutation of the operand dimensions."); + "Transpose dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 864ed43118cd066f6ce14cd808b873f137b8414a..7b65e8c1c9d2bc730c6c8550e9265b69fdde71cf 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) { auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); - ASSERT_FALSE(statusor.ok()); + EXPECT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} +TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_good = ShapeUtil::MakeShape(F32, {4}); + auto values_bad = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_good, &values_bad}); + EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } +TEST_F(ShapeInferenceTest, SortManyValues) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_s32 = ShapeUtil::MakeShape(S32, {4}); + auto values_u32 = ShapeUtil::MakeShape(U32, {4}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_s32, &values_u32}); + EXPECT_IS_OK(statusor); + Shape inferred_shape = statusor.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Compatible( + inferred_shape, + ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); +} + class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 921a984589bb4fb64058a2a56adfe84fe14af69b..56952e3adae59656605a12fd499162504a2a3379 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - tensorflow::gtl::FlatSet deallocated_ptrs; + absl::flat_hash_set deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index 5d1cd1c4422a10e3b9e6ce6fac2c83594bb58b30..ec09dff9244080d24580cad8ee2359a34a6a4f96 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { // Re-use an existing stream from the pool. stream = std::move(streams_.back()); streams_.pop_back(); - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool reusing existing stream"; + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } else { + VLOG(1) << stream->DebugStreamPointers() + << " stream was not ok, StreamPool deleting"; + stream = nullptr; + } } } diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc index aaf5c37b0d250f78cb57639255ac9b59e1b462f7..92f47579d31303b39f6f3a1859789588b586db87 100644 --- a/tensorflow/compiler/xla/service/stream_pool_test.cc +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) { EXPECT_EQ(stream2_ptr, stream3_ptr); } +TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Return the stream, but hold a handle to it. + se::Stream* stream1_ptr = stream1.get(); + stream1 = nullptr; + + // Now stream1 is back in the pool, force an error on the stream. Here we call + // a method that requires DNN support, which we know the Host platform doesn't + // support. + stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1_ptr->ok()); + + // Borrow stream2. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 3e5aa2db60ee31d9fbccf8f7256b15c1b8465335..f95f982eb89d60884b652cd832dff0363372369c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -23,7 +23,7 @@ namespace xla { // HLO pass that folds transpose operators into Dot operators, where the Dot // operator is implemented by a GEMM kernel that can transpose its inputs. -class TransposeFolding : public HloPassInterface { +class TransposeFolding : public HloModulePass { public: using OperandIndices = std::vector; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 6fed7c76d04ad5d8236fecd07aa27f1eda221ea7..ef4e69180ddf3ce4b050cda54c15566763a4999d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -280,16 +280,6 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { - // A kSlice instruction aliases its operand if the backend lowers it to an - // in-place implementation. - if (slice->IsInPlaceSlice()) { - CreateCopiedPointsToSet(slice, slice->operand(0)); - return Status::OK(); - } - return DefaultAction(slice); -} - Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. @@ -455,15 +445,10 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - // kSlice ops that are lowered to an in-place version are expected to not - // define their output buffer. - if (buffer.instruction()->opcode() != HloOpcode::kSlice || - !buffer.instruction()->IsInPlaceSlice()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString(), buffer.instruction()->name()); - } + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString(), buffer.instruction()->name()); } if (buffer.id() < 0 || @@ -771,6 +756,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index a9e8a51e0923362162c6b8a2e97fc334e56d4329..30c365053c5dac5af3c559f7c92b11d389d7fff8 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -36,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -249,7 +247,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; - Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e9a07b14ed685fa4388aca583395370a60176cca..d9ebebf74ed846aa05326a4df72019ef3e71ad88 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1010,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -1035,7 +1073,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 8c91d6e69de637d58fa2ffc1a32ea65f09d3b6d8..e126a530234c1452bcf91f642f63d4c087935a56 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // A pass which simplifies patterns of Tuple and GetTupleElement instructions in // the module. -class TupleSimplifier : public HloPassInterface { +class TupleSimplifier : public HloModulePass { public: TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 56145822be70f391ac3eaab5fc17db4a80e1b9cc..067cfcc17d65860a249de4d9e31703df12091d3a 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -18,7 +18,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 2dba7d7f7574742a301e3503e353bbe57d72a203..577bad6c7062d2ee40271e407e8eed7655fa13bf 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -50,7 +50,7 @@ namespace xla { // conditions as well. // // TODO(b/79121449): We should also sink broadcasts of constants. -class WhileLoopConstantSinking : public HloPassInterface { +class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index e8fe33e62659ae0fffff1ad46e8ba77f715b76b2..9795b2830b6d9add82b89ac76b5438ddc3d2bfe8 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -15,18 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::InlinedVector; -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet; // function hoists the operands in `unhoisted_invariant_instructions` and moves // them into `hoisted_instructions`. static void CreateLoopInvariantCopy( - FlatMap* hoisted_instructions, - FlatSet* unhoisted_invariant_instructions, + flat_hash_map* hoisted_instructions, + flat_hash_set* unhoisted_invariant_instructions, HloInstruction* while_instr, HloInstruction* to_hoist) { HloComputation* parent_of_while = while_instr->parent(); HloComputation* while_body = while_instr->while_body(); @@ -147,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // Maps instructions in the while body to instructions hoisted outside the // while that compute the same value. - FlatMap hoisted_instructions; + flat_hash_map hoisted_instructions; // Contains instructions that can be legally hoisted, but were deemed to be // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we // hoist an instruction in this set, we move it from // unhoisted_invariant_instructions to hoisted_instructions. - FlatSet unhoisted_invariant_instructions; + flat_hash_set unhoisted_invariant_instructions; // Invariant GTE's axiomatically satisfy the constraints for // unhoisted_invariant_instructions -- they can be legally hoisted, but there diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 2cdf20ce80362c0aeb9d8324573e7e9826cc018c..3031899f71e0fd77f20448d9d7489798af01615c 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that rewrites while loops to hoist loop invariant instructions in // the while body into the computation that contains the while instruction. -class WhileLoopInvariantCodeMotion : public HloPassInterface { +class WhileLoopInvariantCodeMotion : public HloModulePass { public: // If `hoist_constants` is true then constants are always hoisted out of while // loop bodies. Otherwise they are only hoisted out if they enable other diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6a7bfe3f129d97866ccc54897d584fab0f7c683e..630d71e5ca25e9d282ce6283284a32d6f725a193 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -114,7 +115,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } - tensorflow::gtl::FlatSet used_tuple_indices; + absl::flat_hash_set used_tuple_indices; for (HloComputation* comp : {while_body, while_cond}) { // The HLO verifier ensures that while_input's shape matches while_init's // shape, which we verified above is a tuple. @@ -181,7 +182,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.end()); std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); - tensorflow::gtl::FlatMap old_to_new_tuple_idx; + absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; old_to_new_tuple_idx[old_idx] = new_idx; @@ -252,7 +253,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond)); + make_while_computation_replacements(while_cond), /*extras=*/{}); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -265,7 +266,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements)); + while_body->CloneWithReplacements(std::move(while_body_replacements), + /*extras=*/{}); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -404,7 +406,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // build a map from the tuple element index to the constant value. Limit this // to scalar constant values because propagating array constants can regress // performance by forcing us to copy constants. - tensorflow::gtl::FlatMap index_to_constant; + absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { HloInstruction* instr = root_operands[i]; if (instr->opcode() == HloOpcode::kGetTupleElement && diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 78024f14dc89ff40a11bbc3602072fda1fe6f312..0bc5a0107bbcfb3b29a01d593fb79b89a863e49b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -30,7 +30,7 @@ namespace xla { // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // -class WhileLoopSimplifier : public HloPassInterface { +class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} absl::string_view name() const override { return "simplify-while-loops"; } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index a7f0e207eb5a81b04bb28977d6f5e38864ad2d6a..87294120d51d244d9f2649cf95916f022bf829cb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -21,7 +21,7 @@ limitations under the License. // HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { -class ZeroSizedHloElimination : public HloPassInterface { +class ZeroSizedHloElimination : public HloModulePass { public: StatusOr Run(HloModule* module) override; absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 96c80fd577e2601c972e374a153f4f0706902ec2..9267de3cfc455ce351ac8cf57f8d45786ea4b5ba 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - CHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + if (shape.dimensions().size() == 1) { + return shape.dimensions()[0]; + } return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); @@ -458,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } -/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && Rank(shape) == 0; +/* static */ bool ShapeUtil::IsScalarWithElementType( + const Shape& shape, PrimitiveType element_type) { + return IsScalar(shape) && shape.element_type() == element_type; } namespace { @@ -828,7 +832,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString()); } @@ -865,11 +870,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return Status::OK(); } - if (Rank(shape) != shape.dimensions_size()) { - return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%d " - "dimensions_size=%d", - Rank(shape), shape.dimensions_size()); + if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + return InvalidArgument("sparse arrays must have rank > 0"); } for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); @@ -1644,7 +1646,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanString(shape); + out << ShapeUtil::HumanStringWithLayout(shape); return out; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 623ae39de819ebecdc8aee27a2b31176421ef020..73f541d50512523b0c5ddd76a9c0427c39c0824f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/container/inlined_vector.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -311,7 +312,10 @@ class ShapeUtil { static bool IsEffectiveScalar(const Shape& shape) { return IsArray(shape) && TrueRank(shape) == 0; } - static bool IsScalarF32(const Shape& shape); + + // Returns whether "shape" is a scalar (array) with the given element_type. + static bool IsScalarWithElementType(const Shape& shape, + PrimitiveType element_type); // Extracts the size of the shape's dimension at dimension number // GetDimensionNumber(dimension_number). @@ -479,8 +483,7 @@ class ShapeUtil { // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. - // - // DEPRECATED: Use Equal() instead. + ABSL_DEPRECATED("Use Equal() instead.") static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 30e3077edb93e1ac740c1d863aacce975ad4c8a5..8a0ae330420531b833ed670118e6b6b1056bd358 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites" load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -150,11 +154,31 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) +tf_cc_test( + name = "hlo_verified_test_base_test", + srcs = ["hlo_verified_test_base_test.cc"], + deps = [ + ":hlo_test_base", + ":hlo_verified_test_base", + ":test_macros_cpu", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -398,6 +422,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -1797,7 +1822,7 @@ xla_test( tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", @@ -2096,7 +2121,7 @@ tf_cc_test( name = "sample_file_test", srcs = ["sample_file_test.cc"], data = ["isolated_convolution.hlo"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":hlo_test_base", "//tensorflow/compiler/xla:test", @@ -2121,11 +2146,11 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2144,3 +2169,21 @@ xla_test( "//tensorflow/core:lib", ], ) + +tf_cc_test( + name = "multiple_devices_on_host_test", + srcs = ["multiple_devices_on_host_test.cc"], + args = ["--xla_force_host_platform_device_count=4"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/synchronization", + ], +) diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 53f2c3bfbfce9585cb68f103a495ce2f1ad8432e..05d4d04034bf50c8bb840e59b28a590fce048c19 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -3,256 +3,266 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): - """Removes "gpu" from a backend list if CUDA is not enabled. - - This allows us to simply hardcode lists including "gpu" here and in the - BUILD file, without causing failures when CUDA isn't enabled.' - - Args: - backends: A list of backends to filter. - - Returns: - The filtered list of backends. - """ - if cuda_is_configured(): - return backends - else: - return [backend for backend in backends if backend != "gpu"] - - -def xla_test(name, - srcs, - deps, - xla_test_library_deps=[], - backends=[], - blacklisted_backends=[], - args=[], - tags=[], - copts=[], - data=[], - backend_tags={}, - backend_args={}, - **kwargs): - """Generates cc_test targets for the given XLA backends. - - This rule generates a cc_test target for one or more XLA backends and also a - platform-agnostic cc_library rule. The arguments are identical to cc_test with - two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "gpu"), and - 'backend_args'/'backend_tags' specifies backend-specific args parameters to - use when generating the cc_test. - - The name of the cc_tests are the provided name argument with the backend name - appended, and the cc_library target name is the provided name argument with - "_lib" appended. For example, if name parameter is "foo_test", then the cpu - test target will be "foo_test_cpu" and the cc_library target is "foo_lib". - - The cc_library target can be used to link with other plugins outside of - xla_test. - - The build rule also defines a test suite ${name} which includes the tests for - each of the supported backends. - - Each generated cc_test target has a tag indicating which backend the test is - for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These - tags can be used to gather tests for a particular backend into a test_suite. - - Examples: - - # Generates the targets: foo_test_cpu and foo_test_gpu. - xla_test( - name = "foo_test", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) + """Removes "gpu" from a backend list if CUDA is not enabled. - # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu - # includes the additional arg "--special_cpu_flag". - xla_test( - name = "bar_test", - srcs = ["bar_test.cc"], - backends = ["cpu", "gpu"], - backend_args = {"cpu": ["--special_cpu_flag"]} - deps = [...], - ) + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' - The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} - to the value 1 where ${BACKEND} is the uppercase name of the backend. - - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - xla_test_library_deps: If set, the generated test targets will depend on the - respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported values: "cpu", - "gpu". If this list is empty, the test will be generated for all supported - backends. - blacklisted_backends: A list of backends to NOT generate tests for. - args: Test arguments for the target. - tags: Tags for the target. - copts: Additional copts to pass to the build. - data: Additional data to pass to the build. - backend_tags: A dict mapping backend name to list of additional tags to - use for that target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. - **kwargs: Additional keyword arguments to pass to native.cc_test. - """ - test_names = [] - if not backends: - backends = all_backends - - backends = [backend for backend in backends - if backend not in blacklisted_backends] - - native.cc_library( - name="%s_lib" % name, - srcs=srcs, - copts=copts, - testonly=True, - deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], - ) - - for backend in filter_backends(backends): - test_name = "%s_%s" % (name, backend) - this_backend_tags = ["xla_%s" % backend] - this_backend_copts = [] - this_backend_args = backend_args.get(backend, []) - this_backend_data = [] - if backend == "cpu": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "gpu": - backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] - this_backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_deps = [] - backend_deps += plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] - this_backend_tags += plugins[backend]["tags"] - this_backend_args += plugins[backend]["args"] - this_backend_data += plugins[backend]["data"] - else: - fail("Unknown backend %s" % backend) - - if xla_test_library_deps: - for lib_dep in xla_test_library_deps: - backend_deps += ["%s_%s" % (lib_dep, backend)] - - tf_cc_test( - name=test_name, - srcs=srcs, - tags=tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, - args=args + this_backend_args, - deps=deps + backend_deps, - data=data + this_backend_data, - **kwargs) - - test_names.append(test_name) - - native.test_suite(name=name, tests=test_names) - -def xla_test_library(name, - srcs, - hdrs=[], - deps=[], - backends=[]): - """Generates cc_library targets for the given XLA backends. - - This rule forces the sources to be compiled for each backend so that the - backend specific macros could expand correctly. It's useful when test targets - in different directories referring to the same sources but test with different - arguments. - - Examples: - - # Generates the targets: foo_test_library_cpu and foo_test_gpu. - xla_test_library( - name = "foo_test_library", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) - # Then use the xla_test rule to generate test targets: - xla_test( - name = "foo_test", - srcs = [], - backends = ["cpu", "gpu"], - deps = [...], - xla_test_library_deps = [":foo_test_library"], - ) + Args: + backends: A list of backends to filter. - Args: - name: Name of the target. - srcs: Sources for the target. - hdrs: Headers for the target. - deps: Dependencies of the target. - backends: A list of backends to generate libraries for. - Supported values: "cpu", "gpu". If this list is empty, the - library will be generated for all supported backends. - """ - - if not backends: - backends = all_backends - - for backend in filter_backends(backends): - this_backend_copts = [] - if backend in ["cpu", "gpu"]: - backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] - elif backend in plugins: - backend_deps = plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] + Returns: + The filtered list of backends. + """ + if cuda_is_configured(): + return backends else: - fail("Unknown backend %s" % backend) + return [backend for backend in backends if backend != "gpu"] + +def xla_test( + name, + srcs, + deps, + xla_test_library_deps = [], + backends = [], + blacklisted_backends = [], + args = [], + tags = [], + copts = [], + data = [], + backend_tags = {}, + backend_args = {}, + **kwargs): + """Generates cc_test targets for the given XLA backends. + + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and + 'backend_args'/'backend_tags' specifies backend-specific args parameters to + use when generating the cc_test. + + The name of the cc_tests are the provided name argument with the backend name + appended, and the cc_library target name is the provided name argument with + "_lib" appended. For example, if name parameter is "foo_test", then the cpu + test target will be "foo_test_cpu" and the cc_library target is "foo_lib". + + The cc_library target can be used to link with other plugins outside of + xla_test. + + The build rule also defines a test suite ${name} which includes the tests for + each of the supported backends. + + Each generated cc_test target has a tag indicating which backend the test is + for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These + tags can be used to gather tests for a particular backend into a test_suite. + + Examples: + + # Generates the targets: foo_test_cpu and foo_test_gpu. + xla_test( + name = "foo_test", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + + # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu + # includes the additional arg "--special_cpu_flag". + xla_test( + name = "bar_test", + srcs = ["bar_test.cc"], + backends = ["cpu", "gpu"], + backend_args = {"cpu": ["--special_cpu_flag"]} + deps = [...], + ) + + The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} + to the value 1 where ${BACKEND} is the uppercase name of the backend. + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + xla_test_library_deps: If set, the generated test targets will depend on the + respective cc_libraries generated by the xla_test_library rule. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". If this list is empty, the test will be generated for all supported + backends. + blacklisted_backends: A list of backends to NOT generate tests for. + args: Test arguments for the target. + tags: Tags for the target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. + backend_tags: A dict mapping backend name to list of additional tags to + use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. + """ + test_names = [] + if not backends: + backends = all_backends + + backends = [ + backend + for backend in backends + if backend not in blacklisted_backends + ] native.cc_library( - name = "%s_%s" % (name, backend), + name = "%s_lib" % name, srcs = srcs, + copts = copts, testonly = True, - hdrs = hdrs, - copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] - + this_backend_copts, - deps = deps + backend_deps, + deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - -def generate_backend_suites(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - native.test_suite(name="%s_tests" % backend, - tags = ["xla_%s" % backend]) - - -def generate_backend_test_macros(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - manifest = "" - if backend in plugins: - manifest = plugins[backend]["disabled_manifest"] - - native.cc_library( - name="test_macros_%s" % backend, - testonly = True, - srcs = ["test_macros.cc"], - hdrs = ["test_macros.h"], - copts = [ - "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), - "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, - ], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", - ]) + for backend in filter_backends(backends): + test_name = "%s_%s" % (name, backend) + this_backend_tags = ["xla_%s" % backend] + this_backend_copts = [] + this_backend_args = backend_args.get(backend, []) + this_backend_data = [] + if backend == "cpu": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + elif backend == "gpu": + backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] + this_backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_deps = [] + backend_deps += plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] + this_backend_data += plugins[backend]["data"] + else: + fail("Unknown backend %s" % backend) + + if xla_test_library_deps: + for lib_dep in xla_test_library_deps: + backend_deps += ["%s_%s" % (lib_dep, backend)] + + tf_cc_test( + name = test_name, + srcs = srcs, + tags = tags + backend_tags.get(backend, []) + this_backend_tags, + extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + args = args + this_backend_args, + deps = deps + backend_deps, + data = data + this_backend_data, + **kwargs + ) + + test_names.append(test_name) + + native.test_suite(name = name, tests = test_names) + +def xla_test_library( + name, + srcs, + hdrs = [], + deps = [], + backends = []): + """Generates cc_library targets for the given XLA backends. + + This rule forces the sources to be compiled for each backend so that the + backend specific macros could expand correctly. It's useful when test targets + in different directories referring to the same sources but test with different + arguments. + + Examples: + + # Generates the targets: foo_test_library_cpu and foo_test_gpu. + xla_test_library( + name = "foo_test_library", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + # Then use the xla_test rule to generate test targets: + xla_test( + name = "foo_test", + srcs = [], + backends = ["cpu", "gpu"], + deps = [...], + xla_test_library_deps = [":foo_test_library"], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + hdrs: Headers for the target. + deps: Dependencies of the target. + backends: A list of backends to generate libraries for. + Supported values: "cpu", "gpu". If this list is empty, the + library will be generated for all supported backends. + """ + + if not backends: + backends = all_backends + + for backend in filter_backends(backends): + this_backend_copts = [] + if backend in ["cpu", "gpu"]: + backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + else: + fail("Unknown backend %s" % backend) + + native.cc_library( + name = "%s_%s" % (name, backend), + srcs = srcs, + testonly = True, + hdrs = hdrs, + copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + deps = deps + backend_deps, + ) + +def generate_backend_suites(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + native.test_suite( + name = "%s_tests" % backend, + tags = ["xla_%s" % backend, "-broken", "manual"], + ) + +def generate_backend_test_macros(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + manifest = "" + if backend in plugins: + manifest = plugins[backend]["disabled_manifest"] + + native.cc_library( + name = "test_macros_%s" % backend, + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + copts = [ + "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), + "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, + ], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], + ) diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 070b092d18930027e215cb43ff917e36cac99f12..b851db14ec048a20947fb8136a31e457d3922f86 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = ConstantR4FromArray4D(&builder, *alhs); auto rhs = ConstantR4FromArray4D(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. + precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index a693fa35954bcb2d95074c94d0aa3eabc1d5fd62..001490c6a8c568656437465054ee4db40d0d8dee 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, - DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { auto module = CreateNewModule(); auto b = HloComputation::Builder(TestName()); @@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest, Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + b.AddInstruction( + HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto module = CreateNewModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + b.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0171f515839d556827f0723772214d175939d386..6c0847a875798870b4362a99ac2ab65d99f9f3e6 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -394,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + // Disable algebraic simplification because the pass may replace a dot + // instruction with a layout-changing multiplication instruction. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); } }; @@ -404,31 +408,18 @@ std::vector CreateNoLayoutAssignmentDotTestParameters() { for (bool lhs_row_major : {true, false}) { for (bool rhs_row_major : {true, false}) { for (bool has_addend : {true, false}) { + // The addend needs to be row major to match the result of the dot. params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } if (n != 1) { params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } } } } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 9c94acb437e9fc948a4255f7112e2e7a40cfa5fb..4d4b676a538947c8dd92a7e34db72e45766cae2c 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -764,8 +764,10 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } -// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. -XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { +// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend +// should not generate layout changing elementwise operations. +#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { const string hlo_text = R"( HloModule Cluster @@ -794,6 +796,7 @@ ENTRY main { LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), result)); } +#endif class FusionClientLibraryTest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39..7ab2ecda58666acd7e9b8587d200a902b75822f3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : test_runner_(test_platform), reference_runner_(reference_platform) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } std::unique_ptr HloTestBase::CreateNewModule(const string& name) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0ae4bdc104d656946d45008adec9ea3960984545..217428befa474448cf2dcbae2eb6cb5b0e61d44c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test { // interpreter is the only supported backend, it will be both the test backend // and the reference backend. HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); ~HloTestBase() override {} diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 8f86c528d0f346b0264948d592660911880f96d1..8bd0a729b77f3ec14204952cb0062103c823883e 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -21,64 +21,68 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} - -HloVerifiedTestBase::~HloVerifiedTestBase() { - // We can't call the ASSERT or EXPECT test macros in destructors, so we - // perform HLO verification in TearDown, and use the CHECK here to ensure - // users don't accidentally override the verification. - CHECK(tear_down_called_) - << "TearDown was never called; subclasses of HloVerifiedTestBase that " - << "override TearDown must call the superclass TearDown."; -} - -void HloVerifiedTestBase::TearDown() { - EXPECT_FALSE(tear_down_called_) - << "TearDown called more than once; it should be called exactly once."; - tear_down_called_ = true; - if (module_) { - VerifyModule(module_.get()); +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); } - for (int i = 0; i < modules_.size(); ++i) { - VerifyModule(modules_.at(i).get()); - } - HloTestBase::TearDown(); + return verifier_.Run(this).status(); } -void HloVerifiedTestBase::VerifyModule(HloModule* module) { - xla::StatusOr mutated = verifier().Run(module); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; } } +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), + verifier_layout_sensitive_(layout_sensitive), + allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} + HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = HloTestBase::CreateNewModule(); + module_ = CreateNewVerifiedModule(TestName()); } return *module_; } HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(HloTestBase::CreateNewModule()); + modules_.emplace_back(CreateNewVerifiedModule(name)); return modules_.back().get(); } void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); - VerifyModule(module_.get()); + module_ = CreateNewVerifiedModule(TestName()); + TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); + module_->VerifyOrAddFailure("after parsing"); } + +StatusOr> +HloVerifiedTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config) { + auto module = CreateNewVerifiedModule(TestName()); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + +std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 8fbc4fa753ebf0c02b44ce10edf9251d28113f98..388a99bb36408665edbc20ade6c6a733d64db88d 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -20,53 +20,84 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { -// A base class for HLO tests that stores a default HloModule, and automatically -// performs verification on that module on tear-down. +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + +// A base class for HLO tests that stores a default VerifiedHloModule. class HloVerifiedTestBase : public HloTestBase { protected: - explicit HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - ~HloVerifiedTestBase() override; + HloVerifiedTestBase(bool layout_sensitive = false, + bool allow_mixed_precision = false); // Constructs a default shape verifier. std::unique_ptr MakeShapeVerifier(); - // Performs verification on the default HloModule returned by module(). - // Automatically called by the testing framework for each test. - // - // REQUIRED: subclasses that override TearDown() must call this explicitly. - void TearDown() override; - // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule& module(); + + ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule* CreateNewModule(const string& name = TestName()); - private: - void VerifyModule(HloModule* module); + // Creates and returns a verified HLO module with the given name. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + private: // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. // // Lazily populated. Access via module(). - std::unique_ptr module_; + std::unique_ptr module_; + // Populated by calls to CreateNewModule. - std::vector> modules_; + std::vector> modules_; - bool tear_down_called_ = false; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c0263e811f94c90a69a460525ffa0c65127ebb5 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// This class includes unit tests which are expected to fail because invalid HLO +// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to +// include the necessary gunit parts to test this test machinery (needs the +// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the +// disabled tests enabled and failures can be manually compared against +// expectations. +class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; + +XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { + // Test shouldn't fail if no module is created at all. +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { + // Use module() to lazily create an empty module, build it up, and verify no + // failures. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { + // Use module() to lazily create an empty module and build up an invalid + // module. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); + + *hlo_module.entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { + // Call CreateNewModule and build up a valid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { + // Call CreateNewModule and build up a invalid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); + + *module->entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndVerifyModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + ParseAndVerifyModule(hlo_string); + EXPECT_EQ(module().entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} + +RANDOM GARBAGE +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleBad + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[1234] add(x,y) +} +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c530591c6e5fe75658dd507d794f8b6a64442871 --- /dev/null +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +StatusOr BuildComputation() { + XlaBuilder b("computation"); + Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); + XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); + return b.Build( + OutfeedWithToken(GetTupleElement(infeed, 0) + + ConstantLiteral(&b, LiteralUtil::CreateR0(1)), + GetTupleElement(infeed, 1), scalar_s32, "")); +} + +void CompileAndExecute( + LocalExecutable* executable, int device_ordinal, LocalClient* client, + absl::Mutex* results_mutex, + std::vector>>* results) { + xla::ExecutableRunOptions execute_options; + execute_options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + execute_options.set_device_ordinal(device_ordinal); + execute_options.set_allocator( + xla::ClientLibrary::GetXlaService(client->platform()) + ->backend() + .memory_allocator()); + StatusOr result = executable->Run({}, execute_options); + { + absl::MutexLock lock(results_mutex); + results->emplace_back(device_ordinal, std::move(result)); + } +} + +void TestWithDeviceCount(const int device_count) { + // Run `device_count` copies of the XLA program built by BuildComputation. + TF_ASSERT_OK_AND_ASSIGN( + se::Platform* const platform, + perftools::gputools::MultiPlatformManager::PlatformWithName("Host")); + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + TF_ASSERT_OK_AND_ASSIGN( + LocalClient* const client, + xla::ClientLibrary::GetOrCreateLocalClient(client_options)); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{})); + std::vector threads; + absl::Mutex results_mutex; + std::vector>> results; + tensorflow::Env* env = tensorflow::Env::Default(); + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + tensorflow::Thread* t = env->StartThread( + tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal), + [&executable, device_ordinal, client, &results_mutex, &results] { + CompileAndExecute(executable.get(), device_ordinal, client, + &results_mutex, &results); + }); + threads.push_back(t); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(client->TransferToInfeedLocal( + LiteralUtil::CreateR0(device_ordinal * 100), device_ordinal)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK_AND_ASSIGN(Literal outfeed, + client->TransferFromOutfeedLocal( + ShapeUtil::MakeShape(S32, {}), device_ordinal)); + EXPECT_EQ(outfeed, LiteralUtil::CreateR0(device_ordinal * 100 + 1)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + delete threads[device_ordinal]; + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(results[device_ordinal].second.status()); + } +} + +// NB! This test requires --xla_force_host_platform_device_count=4 + +TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); } + +TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); } + +TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); } + +TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); } +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 63491a90bf2634a53591e2ab431781f3c4237681..22fe4a2670e2e0e1fedc45036a1ceec19f44e42e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, /*padding=*/padding); CHECK(reducer == kAdd || reducer == kMax); @@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1303,11 +1308,19 @@ struct R1ReduceWindowTestData { /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + // The pattern generated by inclusive scan (cumsum/cumprod). {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, /*strides=*/{1}, /*pad_low=*/{4095}, /*pad_high=*/{0}, /*reducer=*/Reducer::kMax}, + + // The pattern generated by exclusive scan (cumsum/cumprod). + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4096}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( @@ -1361,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index d20dba028a586fa7c93c96dca03c77e3668fa644..b21dd56045e1dc11847e213852dea60cd033be7b 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -507,6 +507,36 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}}); + Literal updates = LiteralUtil::CreateR3({{{-10, 10}, {-40, 40}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, OneScalarIndex) { const char* hlo_text = R"( HloModule OneScalarIndex diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index a40c2d7de6eceea489004f5266d060f26da5d1a8..2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P( R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, // R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, // + R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, // R2Spec{ 511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, // R2Spec{ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 5155f0c652c7c6dbba60c421159494fa28072090..2f18036ff4c5b0bfa28723fb181c33fa6995eb80 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -272,9 +272,11 @@ std::vector FindConstrainedUses( constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); } else if (opcode == HloOpcode::kSort && - instruction->operand_count() == 2 && op_num == 0) { + instruction->operand_count() >= 2 && op_num == 0) { // Operand 0 of sort is the array of keys used for key/value - // (two-operand) kSort instructions. + // (two-operand) kSort instructions. Since sort stability is not + // guaranteed, constrain keys of key-value sort not to have duplicates, + // since otherwise the value order may legitimately differ. constrained_uses.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 181e5cbe290b0df0cf605cc4ef4b8a4945b3d367..bc433eac8fcb02087d8e4eb10f638c85dc141b22 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const float& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } @@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet key_set; + absl::flat_hash_set key_set; for (const int32& value : key_arg.data()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 7abd8651d5ca272f9e82d797870a3bd6b1589615..8b1b9e151992296b9d022ae1d9d974eadd2074a8 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -763,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } -// Test while nodes that share the while body computation. -// TODO(b/37245345): Fails on GPU backend. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { +TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index db5a824de08edeb81b5deb047507dc6158833008..a6e70eb6ca25ffac24a8ebaf0420238e109e4fad 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -83,7 +83,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap* parsed_results, + absl::flat_hash_map* parsed_results, absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; @@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; TF_ASSERT_OK(ParseOneProfileOutputLine( profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); @@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - gtl::FlatMap parsed_profile_lines; + absl::flat_hash_map parsed_profile_lines; for (auto while_body_profile_i = while_body_profile_start + 1; while_body_profile_i != while_body_profile_end; while_body_profile_i++) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index b53f89d63b1edb5fb01ae9e6e71385797ca0f904..60d25a6407476cddba77aadd1df2e3939f5e40ac 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -200,6 +200,15 @@ message DebugOptions { // among different algorithms. bool xla_gpu_crash_on_verification_failures = 101; + // Force the host platform to pretend that there are these many host + // "devices". All these devices are backed by the same threadpool. Defaults + // to 1. + // + // Setting this to anything other than 1 can increase overhead from context + // switching but we let the user override this behavior to help run tests on + // the host that run models in parallel across multiple devices. + int32 xla_force_host_platform_device_count = 102; + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index fda4c31298ebc8c906418afdb8127492b1c5d3f0..40ec1b0ba9b336f5b6407c79c8d63e31219f9b84 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { REGISTER_OP("XRTExecute") - .Attr("Ninputs: int") + .Attr("Ninputs: int >= 0") .Input("computation_handle: int64") .Input("execution_config: string") .Input("input_handles: Ninputs * int64") diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index 09ab4ed95f91d9175cfa2bb555969a59b15762c4..b6dcfc4eb96316b5dad95a65b04d0ae69e4485f6 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -8,6 +8,10 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "raw_api_test_lib", @@ -57,7 +61,7 @@ tf_cuda_cc_test( size = "medium", srcs = [], args = ["--xla_test_device=XLA_GPU"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_gpu_device", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 2952feb16a8a60aecf16be87c9b800d314c4af58..f590fbf0d9d85e6e8b041f6719ab6a14ec9e2191 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, return equal; } +xla::XlaComputation OnePlusTwo() { + xla::XlaBuilder builder("OnePlusTwo"); + auto c0 = xla::ConstantR0(&builder, 1.0f); + auto c1 = xla::ConstantR0(&builder, 2.0f); + xla::Add(c0, c1); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation AddAndScale() { xla::XlaBuilder builder("AddAndScale"); auto p0 = xla::Parameter(&builder, 0, @@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteZeroArg) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto result = ops::XRTExecute(root, c_handle, e_config, + std::initializer_list({})); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(3.0f); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAAllocation p0; p0.set_device_ordinal(0); diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 798f499870095043b77389d0f39306bd4d309259..fa06d351d4e64bfc2fc5e64c81c810185600000a 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -29,6 +29,7 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", "//tensorflow/contrib/copy_graph:copy_graph_py", @@ -60,7 +61,6 @@ py_library( "//tensorflow/contrib/learn", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", "//tensorflow/contrib/libsvm", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lite/python:lite", @@ -113,25 +113,23 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "//tensorflow/contrib/bigtable", + "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols "//tensorflow/contrib/kafka", + "//tensorflow/contrib/kinesis", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ], - "//conditions:default": [], }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ - "//tensorflow/contrib/kinesis", + "//tensorflow:with_ignite_support": [ + "//tensorflow/contrib/ignite", ], "//conditions:default": [], - }) + if_not_windows_cuda([ - "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols - ]) + if_not_windows([ - "//tensorflow/contrib/bigtable", # depends on bigtable - "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows - "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", - ]), + }), ) cc_library( @@ -140,7 +138,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/hadoop:dataset_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", @@ -155,17 +152,13 @@ cc_library( ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", ]) + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_kernels", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ "//tensorflow/contrib/kinesis:dataset_kernels", + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", ], - "//conditions:default": [], }), ) @@ -175,8 +168,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/data:dataset_ops_op_lib", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/hadoop:dataset_ops_op_lib", @@ -192,15 +183,16 @@ cc_library( "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", ] + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_ops_op_lib", + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", ], - "//conditions:default": [], }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ - "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + "//tensorflow:with_ignite_support": [ + "//tensorflow/contrib/ignite:dataset_ops_op_lib", ], "//conditions:default": [], }), diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9478e42b46f363c9ad673ade1ea1ceff27075ff0..f52a1a7babceeae93cdd2e5a93dad413a1d30191 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,14 +21,6 @@ from __future__ import print_function import os -from tensorflow.python.tools import component_api_helper -component_api_helper.package_hook( - parent_package_str=( - "tensorflow.contrib"), - child_package_str=( - "tensorflow_estimator.contrib.estimator")) -del component_api_helper - # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching @@ -63,7 +55,6 @@ from tensorflow.contrib import labeled_tensor from tensorflow.contrib import layers from tensorflow.contrib import learn from tensorflow.contrib import legacy_seq2seq -from tensorflow.contrib import linalg from tensorflow.contrib import linear_optimizer from tensorflow.contrib import lookup from tensorflow.contrib import losses diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index b3f5d92259df8475b205110dd3f0cee1cb5bde6f..9a8f62b9866bf0ac873ac299c963e2c3fc75b577 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -149,7 +149,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): num_devices = num_workers * num_gpus dev_list = ["/replica:0/task:0/device:CPU:0" for _ in range(num_devices)] - with self.test_session(): + with self.cached_session(): input_tensors = self._buildInitialVars(shape, dev_list) un_op = lambda x: math_ops.div( x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT)) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 6ea2db72c411f2f19a06ff108d6b63fc3bde352b..8c277b59e8f36034b99b1f5256473e7f434b624a 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -4,147 +4,6 @@ [deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is moving into TensorFlow core. -The new code location is `tensorflow/python/autograph`. +The new code location is `tensorflow/python/autograph`. Please refer to the +README.md file in that directory. ** - -IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). - -AutoGraph is a Python to TensorFlow compiler. - -With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md). - -For example, this Python function: - -``` -def f(x): - if x < 0: - x = -x - return x -``` - -would be converted to this: - -``` -def graph_mode_f(x): - with tf.name_scope('f'): - - def if_true(): - with tf.name_scope('if_true'): - x_1, = x, - x_1 = tf.negative(x_1) - return x_1, - - def if_false(): - with tf.name_scope('if_false'): - x_1, = x, - return x_1, - x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false) - return x -``` - -so you can use it like an op: - -``` -with tf.Graph().as_default(): - x = tf.constant(-1.0) - - converted_f = autograph.to_graph(f) - y = converted_f(x) - - with tf.Session() as sess: - print(sess.run(y)) - # Output: 1 -``` - -# Getting started - -Use AutoGraph in one of the following ways, described below: - - 1. Annotations (simpler) - 2. Functional API (more flexible) - -To get started, install the latest nightly TensorFlow build: - -```shell -pip install -U tf-nightly -``` - -Then import the `autograph` module from `tf.contrib`: - -``` -from tensorflow.contrib import autograph as ag -``` - -### Related links - -Articles: - - * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7) - -Interactive notebooks: - - * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb) - * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) - * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) - * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb) - * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb) - * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb) - * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb) - -## Using with annotations - -Annotating a function or class with `@convert` converts it in place: - -``` -@ag.convert() -def f(x): - if x < 0: - x = -x - return x -``` - -... so that it always outputs TensorFlow code: - -``` -with tf.Graph().as_default(): - x = tf.constant(-1) - - y = f(x) - - with tf.Session() as sess: - print(sess.run(y)) - # Output: 1 -``` - -## Using the functional API - -The functional API allows you to convert an existing function, class or object after it was defined: - -``` -converted_f = ag.to_graph(f) - -print(converted_f(tf.constant(-1))) -# Output: Tensor - -print(f(-1)) -# Output: 1 -``` - -You can use the functional API to inspect the generated code as well: - -``` -print(ag.to_code(f)) -# Output: -``` - -## Filing bugs and feature requests - -### Reporting a bug - - - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. - - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. - - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you. - -### Requesting a feature - -If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there. diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index b27a19b16c08cb588b45949105a6399623e766e1..648f3ebb05646a66144bcb118347cbc391909409 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -7,64 +7,6 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "batch_scheduler_hdrs", - hdrs = ["batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", - ], -) - -cc_library( - name = "batch_scheduler", - hdrs = ["batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:batch_scheduler", - ], -) - -cc_library( - name = "shared_batch_scheduler_hdrs", - hdrs = ["shared_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs", - ], -) - -cc_library( - name = "shared_batch_scheduler", - hdrs = ["shared_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:shared_batch_scheduler", - ], - alwayslink = 1, -) - -cc_library( - name = "adaptive_shared_batch_scheduler", - hdrs = ["adaptive_shared_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", - ], -) - -cc_library( - name = "serial_device_batch_scheduler", - hdrs = ["serial_device_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", - ], -) - -cc_library( - name = "basic_batch_scheduler", - hdrs = ["basic_batch_scheduler.h"], - deps = [ - "//tensorflow/core/kernels/batching_util:basic_batch_scheduler", - ], -) - load( "//tensorflow:tensorflow.bzl", "py_test", diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 78468145469df216344bc00f116add250dc51dd3..01ee8703a93836d607ee9b765c51c79fe3bb974f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase): def testBasicBatch(self): """Tests that a single batched tensor executes together and only once.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase): def testBatchWithPadding(self): """Test that batching with padding up to an allowed batch size works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase): def testMultipleBatch(self): """Tests that multiple batched tensors execute together.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, _, _ = batch_ops.batch( @@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase): def testIllegalBatchDifferentDim0Sizes(self): """Tests illegally feeding tensors with different dim0 sizes.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( @@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatch(self): """Tests that batch and unbatch work together.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchV1Decorated(self): """Tests that the batch_function_v1 decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) def computation(in_t): return in_t + 1 @@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: # TODO(apassos): Removing this line causes test flakiness! Ideally should # be investigated. default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable @@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase): def testBatchDecoratedWithCapturedInput(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOp(self): """Tests that the batch_function op works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.int32) def computation(in_t): @@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithCapturedInput(self): """Tests that batch_function op works with captured input.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithInputError(self): """Tests that batch_function op works with error in the inputs.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @function.Defun(dtypes.int32, dtypes.int32) @@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @batch_ops.batch_function(1, 10, 100000) def computation(in_t): @@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h deleted file mode 100644 index bf6b7083612018eecf0d1784e60cbbf0c5796fef..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/serial_device_batch_scheduler.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ - -#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD deleted file mode 100644 index 7cb2d8079bd18660f72eab92654629434ce4d6a5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/test_util/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -# Description: Utilities to aid testing. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "fake_clock_env", - testonly = 1, - hdrs = ["fake_clock_env.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/kernels/batching_util:fake_clock_env", - ], -) diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD deleted file mode 100644 index 8f81b6702f2807d7da7e72190ce2d86b28e52113..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/util/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -# Description: Utilities. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "periodic_function_dynamic", - hdrs = ["periodic_function.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", - "//third_party/eigen3", - ], -) - -cc_library( - name = "periodic_function", - visibility = ["//visibility:public"], - deps = [ - ":periodic_function_dynamic", - "//tensorflow/core/kernels/batching_util:periodic_function", - ], -) diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h deleted file mode 100644 index aa2ed0a385125fa090a7a56b6339a87eb2d57b1f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/util/periodic_function.h +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ - -#include "tensorflow/core/kernels/batching_util/periodic_function.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 9e6a146f67796466202cc5074ddd25e4c2b083a6..13215ffabf3a956d3f83697f867457b2fa72e7c9 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -42,7 +42,7 @@ class ExpectationImportanceSampleTest(test.TestCase): def test_normal_integral_mean_and_var_correctly_estimated(self): n = int(1e6) - with self.test_session(): + with self.cached_session(): mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64) mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64) sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64) @@ -72,7 +72,7 @@ class ExpectationImportanceSampleTest(test.TestCase): # Test that importance sampling can correctly estimate the probability that # the product of components in a MultivariateNormal are > 0. n = 1000 - with self.test_session(): + with self.cached_session(): p = mvn_diag_lib.MultivariateNormalDiag( loc=[0.], scale_diag=[1.0, 1.0]) q = mvn_diag_lib.MultivariateNormalDiag( @@ -99,7 +99,7 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase): def test_normal_distribution_second_moment_estimated_correctly(self): # Test the importance sampled estimate against an analytical result. n = int(1e6) - with self.test_session(): + with self.cached_session(): mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64) mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64) sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64) @@ -127,7 +127,7 @@ class GetSamplesTest(test.TestCase): """Test the private method 'get_samples'.""" def test_raises_if_both_z_and_n_are_none(self): - with self.test_session(): + with self.cached_session(): dist = normal_lib.Normal(loc=0., scale=1.) z = None n = None @@ -136,7 +136,7 @@ class GetSamplesTest(test.TestCase): _get_samples(dist, z, n, seed) def test_raises_if_both_z_and_n_are_not_none(self): - with self.test_session(): + with self.cached_session(): dist = normal_lib.Normal(loc=0., scale=1.) z = dist.sample(seed=42) n = 1 @@ -145,7 +145,7 @@ class GetSamplesTest(test.TestCase): _get_samples(dist, z, n, seed) def test_returns_n_samples_if_n_provided(self): - with self.test_session(): + with self.cached_session(): dist = normal_lib.Normal(loc=0., scale=1.) z = None n = 10 @@ -154,7 +154,7 @@ class GetSamplesTest(test.TestCase): self.assertEqual((10,), z.get_shape()) def test_returns_z_if_z_provided(self): - with self.test_session(): + with self.cached_session(): dist = normal_lib.Normal(loc=0., scale=1.) z = dist.sample(10, seed=42) n = None @@ -166,7 +166,7 @@ class GetSamplesTest(test.TestCase): class ExpectationTest(test.TestCase): def test_works_correctly(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6]) p = normal_lib.Normal(loc=x, scale=1.) @@ -213,7 +213,7 @@ class ExpectationTest(test.TestCase): rtol=0.05, atol=0.) def test_docstring_example_normal(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_draws = int(1e5) mu_p = constant_op.constant(0.) mu_q = constant_op.constant(1.) diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 9afe3df585fed6dc7feed1c364a4dac72041257d..18d40fc1dff8e7c9aefffbe3ceba770598a42096 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ 'expectation', @@ -66,7 +67,7 @@ def expectation_importance_sampler(f, shape broadcastable to `q.batch_shape`. For example, `log_p` works "just like" `sampling_dist_q.log_prob`. sampling_dist_q: The sampling distribution. - `tf.contrib.distributions.Distribution`. + `tfp.distributions.Distribution`. `float64` `dtype` recommended. `log_p` and `q` should be supported on the same set. z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`. @@ -141,7 +142,7 @@ def expectation_importance_sampler_logspace( shape broadcastable to `q.batch_shape`. For example, `log_p` works "just like" `q.log_prob`. sampling_dist_q: The sampling distribution. - `tf.contrib.distributions.Distribution`. + `tfp.distributions.Distribution`. `float64` `dtype` recommended. `log_p` and `q` should be supported on the same set. z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`. @@ -188,6 +189,12 @@ def _logspace_mean(log_values): return log_mean_of_values +@deprecation.deprecated( + '2018-10-01', + 'The tf.contrib.bayesflow library has moved to ' + 'TensorFlow Probability (https://github.com/tensorflow/probability). ' + 'Use `tfp.monte_carlo.expectation` instead.', + warn_once=True) def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). @@ -236,17 +243,17 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Example Use: ```python - bf = tf.contrib.bayesflow - ds = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Monte-Carlo approximation of a reparameterized distribution, e.g., Normal. num_draws = int(1e5) - p = ds.Normal(loc=0., scale=1.) - q = ds.Normal(loc=1., scale=2.) - exact_kl_normal_normal = ds.kl_divergence(p, q) + p = tfd.Normal(loc=0., scale=1.) + q = tfd.Normal(loc=1., scale=2.) + exact_kl_normal_normal = tfd.kl_divergence(p, q) # ==> 0.44314718 - approx_kl_normal_normal = bf.expectation( + approx_kl_normal_normal = tfp.monte_carlo.expectation( f=lambda x: p.log_prob(x) - q.log_prob(x), samples=p.sample(num_draws, seed=42), log_prob=p.log_prob, @@ -260,9 +267,9 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, num_draws = int(1e5) p = ds.Gamma(concentration=1., rate=1.) q = ds.Gamma(concentration=2., rate=3.) - exact_kl_gamma_gamma = ds.kl_divergence(p, q) + exact_kl_gamma_gamma = tfd.kl_divergence(p, q) # ==> 0.37999129 - approx_kl_gamma_gamma = bf.expectation( + approx_kl_gamma_gamma = tfp.monte_carlo.expectation( f=lambda x: p.log_prob(x) - q.log_prob(x), samples=p.sample(num_draws, seed=42), log_prob=p.log_prob, @@ -278,7 +285,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, KL-divergence, the following is preferred: ```python - approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence( + approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence( f=bf.kl_reverse, p_log_prob=q.log_prob, q=p, diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index f33eaf7e3df356e10939f591ef75cb4f17978144..2c44abed5e1955cc666273e97e6b2378766f13d2 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -203,7 +203,7 @@ def interleave_fn(index): start = tf.string_join(['training_data_', start_idx_str]) end = tf.string_join(['training_data_', end_idx_str]) return table.scan_range(start_idx, end_idx, columns=columns) -ds = ds.apply(tf.contrib.data.parallel_interleave( +ds = ds.apply(tf.data.experimental.parallel_interleave( interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1)) ``` @@ -249,7 +249,7 @@ def make_row_key_dataset(): - ... - fake-data-23498103 """ - counter_dataset = tf.contrib.data.Counter() + counter_dataset = tf.data.experimental.Counter() width = 8 row_key_prefix = 'fake-data-' ds = counter_dataset.map(lambda index: tf.as_string(index, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 11f530e82a186f410bc505de7fbf1b478240c340..2c6317157d25908c1ff66fc10bd188d93f040521 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -28,6 +28,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { DatasetBase** output) override { BigtableTableResource* table; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + core::ScopedUnref scoped_unref(table); std::vector column_families; std::vector columns; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index 5cab729d9c16f144ec5671ad775f384ad79ad9e0..92a3658667293a934cf3c25510d825d4ef4a993d 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -31,6 +31,7 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref scoped_unref(resource); *output = new Dataset(ctx, resource, std::move(prefix)); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 4dc4647bd24f3a957bc93a9ed8c81b3c7deb6a47..bd8805a3827c6bf9305e5636ce1e89e79cd5cc6d 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -34,6 +34,7 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref scoped_unref(resource); *output = new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 736775bdac10da757190c0b2e4a7672d55edf317..01608dc6bc07890c3a59577ef31c90c2694e6a87 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -38,6 +38,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { BigtableTableResource* resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref scoped_unref(resource); OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), errors::InvalidArgument( diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index 208b7b3e08692c00c1fd879c2a02641fb05ff639..9b60e0a6672c2e8468aa30c671b45be853c092f0 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -28,6 +28,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { BigtableTableResource* resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref scoped_unref(resource); *output = new Dataset(ctx, resource); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 9407855fe88db9faec1949db98a725e5a1cd9f38..688289a4e24afa914a5c1d26c11cac2dc568266b 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -67,6 +67,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { BigtableTableResource* resource; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + core::ScopedUnref scoped_unref(resource); const uint64 num_outputs = columns.size() + 1; std::vector output_shapes; diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index e36f7f32c61b50047c0d9137427f2a24462b1c9a..316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -61,7 +61,7 @@ class BigtableOpsTest(test.TestCase): n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) for i in range(3): @@ -84,7 +84,7 @@ class BigtableOpsTest(test.TestCase): expected_keys.reverse() expected_values = list(self.COMMON_VALUES) expected_values.reverse() - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) for i in range(3): @@ -125,7 +125,7 @@ class BigtableOpsTest(test.TestCase): expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) expected_tuples = zip(expected_keys, expected_values) - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) for i, elem in enumerate(expected_tuples): @@ -144,7 +144,7 @@ class BigtableOpsTest(test.TestCase): itr = ds.make_initializable_iterator() n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) output = sess.run(n) @@ -163,7 +163,7 @@ class BigtableOpsTest(test.TestCase): def runSampleKeyPairsTest(self, ds, expected_key_pairs): itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) for i, elems in enumerate(expected_key_pairs): @@ -219,7 +219,7 @@ class BigtableOpsTest(test.TestCase): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") itr = ds.make_initializable_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -227,7 +227,7 @@ class BigtableOpsTest(test.TestCase): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") itr = ds.make_initializable_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -235,7 +235,7 @@ class BigtableOpsTest(test.TestCase): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) @@ -253,7 +253,7 @@ class BigtableOpsTest(test.TestCase): ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self._writeCommonValues(sess) sess.run(itr.initializer) expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 3e1b6228673fbdcb5a228a11532d29e6b2c817dc..7c87b0daeb09950cc44c51f49c16534d413f0376 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -31,8 +31,8 @@ from six import iteritems from six import string_types from tensorflow.contrib.bigtable.ops import gen_bigtable_ops -from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.util import loader +from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -228,7 +228,7 @@ class BigtableTable(object): """Retrieves a sampling of row keys from the Bigtable table. This dataset is most often used in conjunction with - `tf.contrib.data.parallel_interleave` to construct a set of ranges for + `tf.data.experimental.parallel_interleave` to construct a set of ranges for scanning in parallel. Returns: @@ -575,7 +575,7 @@ def _normalize_columns(columns, provided_kwargs): return normalized -class _BigtableKeyDataset(dataset_ops.Dataset): +class _BigtableKeyDataset(dataset_ops.DatasetSource): """_BigtableKeyDataset is an abstract class representing the keys of a table. """ @@ -645,7 +645,7 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset): table=self._table._resource) # pylint: disable=protected-access -class _BigtableLookupDataset(dataset_ops.Dataset): +class _BigtableLookupDataset(dataset_ops.DatasetSource): """_BigtableLookupDataset represents a dataset that retrieves values for keys. """ @@ -678,7 +678,7 @@ class _BigtableLookupDataset(dataset_ops.Dataset): columns=self._columns) -class _BigtableScanDataset(dataset_ops.Dataset): +class _BigtableScanDataset(dataset_ops.DatasetSource): """_BigtableScanDataset represents a dataset that retrieves keys and values. """ @@ -715,7 +715,7 @@ class _BigtableScanDataset(dataset_ops.Dataset): probability=self._probability) -class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset): +class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table. """ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 5fcb19a47aac492d49b0d8e99af5699bae2ad9f0..14b6fc4ac26f74f54628ae37ad6437c7d3e8caba 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -173,6 +173,7 @@ py_library( py_test( name = "dnn_tree_combined_estimator_test", size = "medium", + timeout = "long", srcs = ["dnn_tree_combined_estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 78232fa0a6e2311c13d4f35acffc3486a9a28803..a3df272e6924792128fc38fd153b9527b58b486e 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -51,6 +51,7 @@ def make_custom_export_strategy(name, feature_columns: A list of feature columns. export_input_fn: A function that takes no arguments and returns an `InputFnOps`. + use_core_columns: A boolean, whether core feature columns were used. Returns: An `ExportStrategy`. @@ -195,8 +196,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, matching_id = categorical_test.value.add() matching_id.int64_value = split.feature_id node.custom_left_child_test.Pack(categorical_test) + elif (node_type == "oblivious_dense_float_binary_split" or + node_type == "oblivious_categorical_id_binary_split"): + raise ValueError("Universal tree format doesn't support oblivious " + "trees") else: - raise ValueError("Unexpected node type %s", node_type) + raise ValueError("Unexpected node type %s" % node_type) node.left_child_id.value = split.left_id node.right_child_id.value = split.right_id return model_and_features @@ -228,6 +233,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + num_sparse_float] + elif node_type == "oblivious_dense_float_binary_split": + split = tree_node.oblivious_dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "oblivious_categorical_id_binary_split": + split = tree_node.oblivious_categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] elif node_type == "categorical_id_set_membership_binary_split": split = tree_node.categorical_id_set_membership_binary_split split_column = feature_names[split.feature_column + num_dense_floats + @@ -236,7 +248,7 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, assert tree_node.node_metadata.gain == 0 continue else: - raise ValueError("Unexpected split type %s", node_type) + raise ValueError("Unexpected split type %s" % node_type) # Apply shrinkage factor. It is important since it is not always uniform # across different trees. sums[split_column] += ( diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 3b28ed77f325b3f8b09fe6b9d2776eff82ff53a7..8edb5d6c640611bbb90d7731b2fea4354e125563 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -579,13 +579,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { const int end_index = partition_boundaries[non_empty_partitions[root_idx]][j + 1] .start_index; - CHECK(bucket_ids_and_dimensions(start_index, 1) == - bucket_ids_and_dimensions(end_index - 1, 1)) - << "For bucket " << bucket_ids_and_dimensions(start_index, 0) - << " the dimension was " - << bucket_ids_and_dimensions(start_index, 1) << " and for " - << bucket_ids_and_dimensions(end_index - 1, 0) << " " - << bucket_ids_and_dimensions(end_index - 1, 1); if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) @@ -746,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; - std::vector non_empty_partitions; - for (int i = 0; i < partition_ids.size() - 1; ++i) { + partition_boundaries.push_back(0); + for (int i = 1; i < partition_ids.size(); ++i) { // Make sure the input is sorted by partition_ids; - CHECK_LE(partition_ids(i), partition_ids(i + 1)); - if (i == 0 || partition_ids(i) != partition_ids(i - 1)) { + OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i), + errors::InvalidArgument("Partition IDs must be sorted.")); + if (partition_ids(i) != partition_ids(i - 1)) { partition_boundaries.push_back(i); - // Some partitions might only have bias feature. We don't want to split - // those so check that the partition has at least 2 features. - if (partition_ids(i) == partition_ids(i + 1)) { - non_empty_partitions.push_back(partition_boundaries.size() - 1); - } } } - if (partition_ids.size() > 0) { - partition_boundaries.push_back(partition_ids.size()); + std::vector non_empty_partitions; + partition_boundaries.push_back(partition_ids.size()); + for (int i = 0; i < partition_boundaries.size() - 1; ++i) { + // We want to ignore partitions with only the bias term. + if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) { + non_empty_partitions.push_back(i); + } } int num_elements = non_empty_partitions.size(); Tensor* output_partition_ids_t = nullptr; @@ -862,6 +856,15 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(state->feature_column_group_id()); + CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id) + << "Unexpected feature ID selected. " + << "Start feature ID: [" << start_index << "] " + << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1) + << "\nBest feature ID: [" << best_feature_idx << "] " + << feature_ids(best_feature_idx, 0) << ", " + << feature_ids(best_feature_idx, 1) + << "\nPartition IDS: " << partition_ids(start_index) << " " + << partition_ids(best_feature_idx); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 35d727482bf631f2fe14e02c1ec4b75a763e8615..4da25298cb82093ac501997cc21c48265df06860 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -_BIAS_FEATURE_ID = -1 +_BIAS_FEATURE_ID = int(dtypes.int64.min) class EqualitySplitHandler(base_split_handler.BaseSplitHandler): diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 94ea7bc2eb7b098a0628683167510bf4e3c2426e..a2f708081a4b484d649b5d09b172c2c60db69aeb 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -170,7 +170,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testObliviousFeatureSplitGeneration(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 1 | 1 | @@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testLastOneEmpty(self): + with self.cached_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 0 | 1,2 | + # i1 | (-0.5, 0.07) | 0 | | + # i2 | (1.2, 0.2) | 0 | 2 | + # i3 | (4.0, 0.13) | 1 | | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [0, 0, 0, 1] + indices = [[0, 0], [0, 1], [2, 0]] + values = array_ops.constant([1, 2, 2], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([0], partitions) + + # Check the split on partition 0. + # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1) + expected_left_weight = -0.9848484848484846 + + # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1) + expected_left_gain = 1.2803030303030298 + + # -(-0.5 + 0.1) / (0.07 + 1) + expected_right_weight = 0.37383177570093457 + + # (-0.5 + 0.1) ** 2 / (0.07 + 1) + expected_right_gain = 0.14953271028037385 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain = 0.46043165467625885 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(2, split_node.feature_id) + + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index c7eb2493a8ba56943740326cf68ad6b3a91f67c4..8531e97f90236b8e5eb64bc0f4c9bb3b674f35cd 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._num_quantiles = num_quantiles - self._max_tree_depth = variables.Variable( + self._max_tree_depth = variables.VariableV1( initial_value=self._learner_config.constraints.max_tree_depth) - self._attempted_trees = variables.Variable( + self._attempted_trees = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, name="attempted_trees") - self._finalized_trees = variables.Variable( + self._finalized_trees = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, name="finalized_trees") @@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object): fc_name_idx += 1 # Create ensemble stats variables. - num_layer_examples = variables.Variable( + num_layer_examples = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layer_examples", trainable=False) - num_layer_steps = variables.Variable( + num_layer_steps = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layer_steps", trainable=False) - num_layers = variables.Variable( + num_layers = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layers", trainable=False) - active_tree = variables.Variable( + active_tree = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="active_tree", trainable=False) - active_layer = variables.Variable( + active_layer = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="active_layer", trainable=False) # Variable that becomes false once bias centering is done. - continue_centering = variables.Variable( + continue_centering = variables.VariableV1( initial_value=self._center_bias, name="continue_centering", trainable=False) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 73e41bc4571cabb51ee96812c01f0db7c0dfdd3c..6d20a2e7f482953481fb1effe4c6e2e5a300786f 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -86,7 +86,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testExtractFeatures(self): """Tests feature extraction.""" - with self.test_session(): + with self.cached_session(): features = {} features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32) features["sparse_float"] = sparse_tensor.SparseTensor( @@ -128,7 +128,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testExtractFeaturesWithTransformation(self): """Tests feature extraction.""" - with self.test_session(): + with self.cached_session(): features = {} features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32) features["sparse_float"] = sparse_tensor.SparseTensor( @@ -178,7 +178,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testExtractFeaturesFromCoreFeatureColumns(self): """Tests feature extraction when using core columns.""" - with self.test_session(): + with self.cached_session(): features = {} # Sparse float column does not exist in core, so only dense numeric and # categorical. @@ -213,7 +213,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefNoBiasCentering(self): """Tests the train function running on chief without bias centering.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -316,7 +316,7 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertProtoEquals(expected_tree, output.trees[0]) def testObliviousDecisionTreeAsWeakLearner(self): - with self.test_session(): + with self.cached_session(): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -473,7 +473,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefSparseAndDense(self): """Tests the train function with sparse and dense features.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -580,7 +580,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -685,7 +685,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefWithBiasCentering(self): """Tests the train function running on chief with bias centering.""" - with self.test_session(): + with self.cached_session(): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -757,7 +757,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnNonChiefNoBiasCentering(self): """Tests the train function running on worker without bias centering.""" - with self.test_session(): + with self.cached_session(): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -821,7 +821,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnNonChiefWithCentering(self): """Tests the train function running on worker with bias centering.""" - with self.test_session(): + with self.cached_session(): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -885,7 +885,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testPredictFn(self): """Tests the predict function.""" - with self.test_session() as sess: + with self.cached_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -939,7 +939,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testPredictFnWithLeafIndexAdvancedLeft(self): """Tests the predict function with output leaf ids.""" - with self.test_session() as sess: + with self.cached_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -1051,7 +1051,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") @@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1155,7 +1155,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnMulticlassDiagonalHessian(self): """Tests the GBDT train for multiclass diagonal hessian.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") @@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1259,7 +1259,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnMulticlassTreePerClass(self): """Tests the GBDT train for multiclass tree per class strategy.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") @@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1374,7 +1374,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): """Tests the train function running on chief with feature selection.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1493,7 +1493,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefFeatureSelectionWithGoodSplits(self): """Tests the train function running on chief with feature selection.""" - with self.test_session() as sess: + with self.cached_session() as sess: ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() @@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1610,7 +1610,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): """Tests the train function running on chief with feature selection.""" - with self.test_session() as sess: + with self.cached_session() as sess: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() @@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1720,7 +1720,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testResetModelBeforeAndAfterSplit(self): """Tests whether resetting works.""" - with self.test_session(): + with self.cached_session(): # First build a small tree and train it to verify training works. ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") @@ -1854,7 +1854,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testResetModelNonChief(self): """Tests the reset function on a non-chief worker.""" - with self.test_session(): + with self.cached_session(): # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -1930,7 +1930,7 @@ class GbdtTest(test_util.TensorFlowTestCase): def testResetModelWithCenterBias(self): """Tests the reset function running on chief with bias centering.""" - with self.test_session(): + with self.cached_session(): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, tree_ensemble_config="", name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py index ccb8509c0347f9c9b6f1e8f4f620230aac9a6c2d..cc22504c8f34e7df30c5676a436d31b452bc9496 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py @@ -45,7 +45,7 @@ class LossesTest(test_util.TensorFlowTestCase): eps = 0.2 - with self.test_session(): + with self.cached_session(): predictions_tensor = constant_op.constant( prediction_logits, dtype=dtypes.float32) loss_for_positives, _ = losses.per_example_exp_loss( @@ -84,7 +84,7 @@ class LossesTest(test_util.TensorFlowTestCase): predictions = np.array( [[0.123], [23.2], [233], [52], [3]], dtype=np.float32) - with self.test_session(): + with self.cached_session(): loss_tensor, _ = losses.per_example_squared_loss(labels, weights, predictions) diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 150d734db6cdd8023ab6d91a49872f657bcdbdea..94b7f4f867655bf7fdf94e8488eeae7088c41622 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -37,6 +37,7 @@ Checkpoint management: Saving and restoring Python state: @@NumpyState +@@PythonStateWrapper """ from __future__ import absolute_import @@ -45,6 +46,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.python_state import NumpyState +from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 9b11035b6d277851ea0a0071062bf5cf6b6b2185..302d5cfb79a08b6adf52ebd44533152c5454eadc 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import functools +import six import numpy @@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase): # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making # ndarrays checkpointable natively and using standard checkpointable list # tracking. - if isinstance(value, numpy.ndarray): + if isinstance(value, (numpy.ndarray, numpy.generic)): try: existing = super(NumpyState, self).__getattribute__(name) existing.array = value @@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase): super(NumpyState, self).__setattr__(name, value) -class _NumpyWrapper(base.CheckpointableBase): +@six.add_metaclass(abc.ABCMeta) +class PythonStateWrapper(base.CheckpointableBase): + """Wraps a Python object for storage in an object-based checkpoint.""" + + @abc.abstractmethod + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the object.""" + + @abc.abstractmethod + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the object.""" + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "py_state": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } + + +class _NumpyWrapper(PythonStateWrapper): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase): self.array = array def _serialize(self): - """Callback for `PythonStringStateSaveable` to serialize the array.""" + """Callback to serialize the array.""" string_file = BytesIO() try: numpy.save(string_file, self.array, allow_pickle=False) @@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase): return serialized def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the array.""" + """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "array": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 0439a4755e36fc3be6e065d18d3e835feda8aab3..45494351ff4e6c8c75634d8563c3fb63c6089036 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase): save_state.a = numpy.ones([2, 2]) save_state.b = numpy.ones([2, 2]) save_state.b = numpy.zeros([2, 2]) + save_state.c = numpy.int64(3) self.assertAllEqual(numpy.ones([2, 2]), save_state.a) self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + self.assertEqual(3, save_state.c) first_save_path = saver.save(prefix) save_state.a[1, 1] = 2. + save_state.c = numpy.int64(4) second_save_path = saver.save(prefix) load_state = python_state.NumpyState() @@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase): loader.restore(first_save_path).initialize_or_restore() self.assertAllEqual(numpy.ones([2, 2]), load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(3, load_state.c) load_state.a[0, 0] = 42. self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) loader.restore(first_save_path).run_restore_ops() @@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase): loader.restore(second_save_path).run_restore_ops() self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(4, load_state.c) def testNoGraphPollution(self): graph = ops.Graph() diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1056894f18f1ec19a598dfbd1161d7f9bea7e94f..f4a8e16c99f464b813a98e981579bd0ff53bd464 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -60,6 +60,7 @@ class TPUClusterResolver(ClusterResolver): if (self._tpu == compat.as_bytes('') or self._tpu == compat.as_bytes('local') or self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('localhost:')) or self._tpu.startswith(compat.as_bytes('grpc://'))): return False return True diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index ebcabb42230c86cfb2ae280c83092b9006033e7d..244683765a75626acd932ef8a10d8e5b6639ebb0 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -1,6 +1,16 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.5) +if(WIN32) + if(${CMAKE_VERSION} VERSION_LESS "3.8") + message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.") + else() + if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64") + message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake.") + endif() + endif() +endif() + # Project project(tensorflow C CXX) @@ -30,7 +40,6 @@ endif() option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) -option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON) option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON) option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON) @@ -218,10 +227,6 @@ if (tensorflow_WIN_CPU_SIMD_OPTIONS) endif() endif() -if (tensorflow_ENABLE_JEMALLOC_SUPPORT) - add_definitions(-DTENSORFLOW_USE_JEMALLOC -DJEMALLOC_EXPORT=) -endif() - # External dependencies include(zlib) include(gif) @@ -329,12 +334,6 @@ if(tensorflow_ENABLE_GRPC_SUPPORT) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) endif() endif() -if(tensorflow_ENABLE_JEMALLOC_SUPPORT) - include(jemalloc) - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${jemalloc_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc) - include_directories(${jemalloc_INCLUDE_DIRS}) -endif() if(tensorflow_ENABLE_SNAPPY_SUPPORT) include(snappy) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES}) @@ -353,7 +352,7 @@ endif() # MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) - add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) + add_definitions(-DINTEL_MKL -DEIGEN_USE_VML -DENABLE_MKL) include(mkl) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination) @@ -363,9 +362,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) - else (tensorflow_ENABLE_MKLDNN_SUPPORT) - add_definitions(-DINTEL_MKL_ML_ONLY) - endif() + endif(tensorflow_ENABLE_MKLDNN_SUPPORT) endif (tensorflow_ENABLE_MKL_SUPPORT) if (tensorflow_ENABLE_GPU) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 0b79f718d4823a987e02804f59a432ee46d0ada3..84c679162c3ed8ffc9babcd3af583b26fb62c2d6 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -1,6 +1,10 @@ TensorFlow CMake build ====================== +CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all +platforms. For details, see the +[TensorFlow install guide](https://www.tensorflow.org/install/). + This directory contains CMake files for building TensorFlow on Microsoft Windows. [CMake](https://cmake.org) is a cross-platform tool that can generate build scripts for multiple build systems, including Microsoft @@ -13,7 +17,7 @@ Linux. Current Status -------------- -CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/install_windows) +CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows) for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations @@ -104,180 +108,177 @@ ops or APIs. Step-by-step Windows build ========================== -1. Install the prerequisites detailed above, and set up your environment. - - * The following commands assume that you are using the Windows Command - Prompt (`cmd.exe`). You will need to set up your environment to use the - appropriate toolchain, i.e. the 64-bit tools. (Some of the binary targets - we will build are too large for the 32-bit tools, and they will fail with - out-of-memory errors.) The typical command to do set up your - environment is: - - ``` - D:\temp> "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat" - ``` - - * When building with GPU support after installing the CUDNN zip file from NVidia, append its - bin directory to your PATH environment variable. - In case TensorFlow fails to find the CUDA dll's during initialization, check your PATH environment variable. - It should contain the directory of the CUDA dlls and the directory of the CUDNN dll. - For example: - - ``` - D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin - D:\local\cuda\bin - ``` - - * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable. - - In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable. - It should contain the directory of the MKL dlls. For example: - - ``` - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt - ``` - - - * We assume that `cmake` and `git` are installed and in your `%PATH%`. If - for example `cmake` is not in your path and it is installed in - `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory - to your `%PATH%` as follows: - - ``` - D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" - ``` - -2. Clone the TensorFlow repository and create a working directory for your - build: - - ``` - D:\temp> git clone https://github.com/tensorflow/tensorflow.git - D:\temp> cd tensorflow\tensorflow\contrib\cmake - D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build - D:\temp\tensorflow\tensorflow\contrib\cmake> cd build - D:\temp\tensorflow\tensorflow\contrib\cmake\build> - ``` - -3. Invoke CMake to create Visual Studio solution and project files. - - **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment - variable. The other paths are for illustrative purposes only, and may - be different on your platform. The `^` character is a line continuation - and must be the last character on each line. - - ``` - D:\...\build> cmake .. -A x64 -DCMAKE_BUILD_TYPE=Release ^ - More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ - More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ - More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib - ``` - To build with GPU support add "^" at the end of the last line above following with: - ``` - More? -Dtensorflow_ENABLE_GPU=ON ^ - More? -DCUDNN_HOME="D:\...\cudnn" - ``` - To build with MKL support add "^" at the end of the last line above following with: - - ``` - More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ - More? -DMKL_HOME="D:\...\compilers_and_libraries" - ``` - - To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: - - ``` - More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX - ``` - - Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build - configuration that you choose when invoking `msbuild`. The known-good - values are `Release` and `RelWithDebInfo`. The `Debug` build type is - not currently supported, because it relies on a `Debug` library for - Python (`python35d.lib`) that is not distributed by default. - - There are various options that can be specified when generating the - solution and project files: - - * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the - `CMAKE_BUILD_TYPE` option must match the build configuration that you - choose when invoking MSBuild in step 4. The known-good values are - `Release` and `RelWithDebInfo`. The `Debug` build type is not currently - supported, because it relies on a `Debug` library for Python - (`python35d.lib`) that is not distributed by default. - - * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can - build a small subset of the kernels for a faster build by setting this - option to `OFF`. - - * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate - project files for a simple C++ - [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). - - * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. Generate - project files for building a PIP package containing the TensorFlow runtime - and its Python bindings. - - * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include - gRPC support and the distributed client and server code in the TensorFlow - runtime. - - * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include - SSL support (for making secure HTTP requests) in the TensorFlow runtime. - This support is incomplete, and will be used for Google Cloud Storage - support. - - * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include - GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1. - CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn. - - * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests. - There are many of them and building will take a few hours. - After cmake, build and execute the tests with - ``` - MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests. - After building the python wheel, you need to install the new wheel before running the tests. - To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on - serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. - After building the python wheel, you need to install the new wheel before running the tests. - To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). - CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl. - - * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. - - -4. Invoke MSBuild to build TensorFlow. - - To build the C++ example program, which will be created as a `.exe` - executable in the subdirectory `.\Release`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj - D:\...\build> Release\tf_tutorials_example_trainer.exe - ``` - - To build the PIP package, which will be created as a `.whl` file in the - subdirectory `.\tf_python\dist`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj - ``` - +1. Install the prerequisites detailed above, and set up your environment. + + * When building with GPU support after installing the CUDNN zip file from + NVidia, append its bin directory to your PATH environment variable. In + case TensorFlow fails to find the CUDA dll's during initialization, + check your PATH environment variable. It should contain the directory of + the CUDA dlls and the directory of the CUDNN dll. For example: + + ``` + D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin + D:\local\cuda\bin + ``` + + * When building with MKL support after installing + [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin + directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL dll's during initialization, + check your PATH environment variable. It should contain the directory of + the MKL dlls. For example: + + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If + for example `cmake` is not in your path and it is installed in + `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory + to your `%PATH%` as follows: + + ``` + D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" + ``` + +2. Clone the TensorFlow repository and create a working directory for your + build: + + ``` + D:\temp> git clone https://github.com/tensorflow/tensorflow.git + D:\temp> cd tensorflow\tensorflow\contrib\cmake + D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build + D:\temp\tensorflow\tensorflow\contrib\cmake> cd build + D:\temp\tensorflow\tensorflow\contrib\cmake\build> + ``` + +3. Invoke CMake to create Visual Studio solution and project files. + + **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment + variable. The other paths are for illustrative purposes only, and may be + different on your platform. The `^` character is a line continuation and + must be the last character on each line. + + ``` + D:\...\build> cmake .. -A x64 -Thost=x64 -DCMAKE_BUILD_TYPE=Release ^ + More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ + More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ + More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib + ``` + + To build with GPU support add "^" at the end of the last line above + following with: `More? -Dtensorflow_ENABLE_GPU=ON ^ More? + -DCUDNN_HOME="D:\...\cudnn"` To build with MKL support add "^" at the end of + the last line above following with: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + + To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: + + ``` + More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX + ``` + + Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build + configuration that you choose when invoking `msbuild`. The known-good values + are `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + The `-Thost=x64` flag will ensure that the 64 bit compiler and linker is + used when building. Without this flag, MSBuild will use the 32 bit toolchain + which is prone to compile errors such as "compiler out of heap space". + + There are various options that can be specified when generating the solution + and project files: + + * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the + `CMAKE_BUILD_TYPE` option must match the build configuration that you + choose when invoking MSBuild in step 4. The known-good values are + `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can + build a small subset of the kernels for a faster build by setting this + option to `OFF`. + + * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate + project files for a simple C++ + [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). + + * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. + Generate project files for building a PIP package containing the + TensorFlow runtime and its Python bindings. + + * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include + gRPC support and the distributed client and server code in the + TensorFlow runtime. + + * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include + SSL support (for making secure HTTP requests) in the TensorFlow runtime. + This support is incomplete, and will be used for Google Cloud Storage + support. + + * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include GPU + support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and + CUDNN 5.1. CMake will expect the location of CUDNN in + -DCUDNN_HOME=path_you_unzipped_cudnn. + + * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds + cc unit tests. There are many of them and building will take a few + hours. After cmake, build and execute the tests with `MSBuild + /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj ctest -C + RelWithDebInfo` + + * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python kernel tests. After building the python wheel, you need + to install the new wheel before running the tests. To execute the tests, + use `ctest -C RelWithDebInfo` + + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python tests on serveral major packages. This option is only + valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. + After building the python wheel, you need to install the new wheel + before running the tests. To execute the tests, use `ctest -C + RelWithDebInfo` + + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include + MKL support. If MKL is enabled you need to install the + [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). CMake + will expect the location of MKL in -MKL_HOME=path_you_install_mkl. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. + Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for + Deep Neural Networks (Intel(R) + MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add + `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + +4. Invoke MSBuild to build TensorFlow. + + Set up the path to find MSbuild: `D:\temp> "C:\Program Files (x86)\Microsoft + Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"` + + To build the C++ example program, which will be created as a `.exe` + executable in the subdirectory `.\Release`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj + D:\...\build> Release\tf_tutorials_example_trainer.exe + ``` + + To build the PIP package, which will be created as a `.whl` file in the + subdirectory `.\tf_python\dist`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj + ``` Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake deleted file mode 100644 index afadcc007d66414be3306e91e7186a00b6e587ce..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -include (ExternalProject) - -set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include) -set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) -set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0) -set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc) - -if (WIN32) - set(jemalloc_INCLUDE_DIRS - ${jemalloc_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat - ) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) - else() - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/jemalloc.lib) - endif() -else() - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a) -endif() - -ExternalProject_Add(jemalloc - PREFIX jemalloc - URL ${jemalloc_URL} - URL_HASH ${jemalloc_HASH} - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 - BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc - INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." - CMAKE_CACHE_ARGS - -DCMAKE_BUILD_TYPE:STRING=Release - -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -Dwith-jemalloc-prefix:STRING=jemalloc_ - -Dwithout-export:BOOL=ON -) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index ad2af01bc002555ce48f8b9bfb7d8d724a1a7dc8..1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== include (ExternalProject) +include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) @@ -35,7 +36,7 @@ if(WIN32) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a) endif() set(png_HEADERS diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index f56fb35a0f71250f00b84e5cf94a24682bda6c82..56a57a2340ddc7f923c611c222a0399e279ad58a 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG v3.6.0) +set(PROTOBUF_TAG v3.6.1) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/make.bat b/tensorflow/contrib/cmake/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..d52b24e01d6590180106ba6cee2782c99d734ee3 --- /dev/null +++ b/tensorflow/contrib/cmake/make.bat @@ -0,0 +1,38 @@ +%echo off + +cd /d %~dp0 + +if exist _build rd /s /q _build + +mkdir _build +chdir _build + + +rem cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install + +CALL :NORMALIZEPATH "..\..\..\.." +SET SOURCE_DIR=%RETVAL% + +echo %SOURCE_DIR% + +SET SOURCE_DIR=F:\frameworks\tensorflow\ + +CALL :NORMALIZEPATH "../../../tools/git/gen_git_source.py" +SET SOURCE_PYTHON_SCRIPT=%RETVAL% + +CALL :NORMALIZEPATH "../../../core/util/version_info.cc" +SET SOURCE_VERSION_CC=%RETVAL% + +python %SOURCE_PYTHON_SCRIPT% --raw_generate %SOURCE_VERSION_CC% --source_dir %SOURCE_DIR% --git_tag_override= + +cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install + +EXIT /B + +:NORMALIZEPATH + SET RETVAL=%~dpfn1 + EXIT /B + + + + \ No newline at end of file diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fb871acae9963978485afef52dbba089aea4fd40..6e72670142d560a364350bb4769f1153f884b0f6 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -132,10 +132,8 @@ tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops tensorflow/contrib/data -tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests -tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto @@ -207,6 +205,8 @@ tensorflow/contrib/integrate/python tensorflow/contrib/integrate/python/ops tensorflow/contrib/kafka/python tensorflow/contrib/kafka/python/ops +tensorflow/contrib/ignite/python +tensorflow/contrib/ignite/python/ops tensorflow/contrib/keras tensorflow/contrib/keras/api tensorflow/contrib/keras/api/keras @@ -273,9 +273,6 @@ tensorflow/contrib/libsvm tensorflow/contrib/libsvm/python tensorflow/contrib/libsvm/python/kernel_tests tensorflow/contrib/libsvm/python/ops -tensorflow/contrib/linalg -tensorflow/contrib/linalg/python -tensorflow/contrib/linalg/python/ops tensorflow/contrib/linear_optimizer tensorflow/contrib/linear_optimizer/kernels tensorflow/contrib/linear_optimizer/kernels/g3doc @@ -409,7 +406,6 @@ tensorflow/contrib/summary tensorflow/contrib/tensorboard tensorflow/contrib/tensorboard/plugins tensorflow/contrib/tensorboard/plugins/projector -tensorflow/contrib/tensorboard/plugins/trace # TODO(sami): Add cmake implementations. # tensorflow/contrib/tensorrt/python # tensorflow/contrib/tensorrt/python/ops diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index cf1ee2ad76f2cc9f58dbe90182a3e17f1edc7ed3..42afbd9105ef3789430606d909979ca308e2eaa8 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -12,7 +12,6 @@ tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto tensorflow/contrib/tensorboard/plugins/projector -tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/profiler tensorflow/contrib/training/python/training diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 067c299a71cd4ac96878bcf27b4453466785e4ba..7e806685b8448cbd629985cdc00ed1193857abe6 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -258,14 +258,21 @@ add_dependencies(tf_core_lib ${tensorflow_EXTERNAL_DEPENDENCIES} tf_protos_cc) # force_rebuild always runs forcing ${VERSION_INFO_CC} target to run # ${VERSION_INFO_CC} would cache, but it depends on a phony never produced # target. -set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) -add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC}) -add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) -add_custom_command(OUTPUT - ${VERSION_INFO_CC} - COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py - ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} - DEPENDS __force_rebuild) +# This code forces rebuild every time, not needed as version from git is fetched only once +# move to make.bat which mimicks make.sh + +if (NOT WIN32) + + set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) + add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC}) + add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) + add_custom_command(OUTPUT + ${VERSION_INFO_CC} + COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py + ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} + DEPENDS __force_rebuild) +endif() + set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) ######################################################## diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 2c878c17167c662d10a8c7dabf41687efdbf65d8..ed31351d9eae3fad2a58caf2c80bfc691648adb8 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_src_py} "${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py" diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py index d5e14e7a641b5673e97882daf2b5a1796ee1bbef..f5431ca1ffd0a6ed16e32dd89007ca28ab54f5db 100644 --- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py +++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py @@ -45,7 +45,7 @@ class CoderOpsTest(test.TestCase): decoded = coder_ops.range_decode( encoded, array_ops.shape(data), cdf, precision=14) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual(*sess.run((data, decoded))) diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index d7583be6d8ed996ac894d3a8601f716cc27bdd86..f83386b8a4246ff2d7acdd2190804296582ee945 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -5,7 +5,10 @@ package(default_visibility = [":friends"]) package_group( name = "friends", includes = ["//tensorflow/compiler/jit:friends"], - packages = ["//tensorflow/..."], + packages = [ + "//tensorflow/...", + "//third_party/py/tensor2tensor/...", + ], ) load("//tensorflow:tensorflow.bzl", "tf_py_test") @@ -53,12 +56,16 @@ py_library( srcs = ["xla.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/compiler/jit:xla_ops_py", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:summary_op_util", "//tensorflow/python:util", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", ], ) diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 42b3b9f026c425ebe96c07edae67ddaad65bba87..3e631b59094b182384c27031b3c780abd57a1bcb 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,7 +173,7 @@ class JITTest(test.TestCase): class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([[3.]]) y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 60f5af166234ba69e21a4a64cd3b3c102f66aef4..873b03580d6f1d9cb25c79cb31989d43cdb8c9a7 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -12,20 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""xla provides experimental xla support API.""" +"""xla is an experimental library that provides XLA support APIs.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import contextlib from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.compiler.jit.ops import xla_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_op_util +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util import function_utils +from tensorflow.python.util import tf_decorator _XLA_COMPILE_ATTR = '_xla_compile_id' _MAX_WARNING_LINES = 5 @@ -51,6 +60,30 @@ _UNSUPPORTED_OPS = set([ ]) +def compile(computation, inputs=None): # pylint: disable=redefined-builtin + """Builds an operator that compiles and runs `computation` with XLA. + + Args: + computation: A Python function that builds a computation to apply to the + input. If the function takes n inputs, 'inputs' should be a list of n + tensors. + + `computation` may return a list of operations and tensors. Tensors must + come before operations in the returned list. The return value of + `compile` is a list of tensors corresponding to the tensors from the + output of `computation`. + + All `Operation`s returned from `computation` will be executed when + evaluating any of the returned output tensors. + inputs: A list of input tensors or `None` (equivalent to an empty list). + + Returns: + A list of output tensors. + """ + # pylint: disable=protected-access + return _compile_internal(computation, inputs) + + class XLACompileContext(control_flow_ops.XLAControlFlowContext): """A `ControlFlowContext` for nodes inside an XLA computation cluster. @@ -206,3 +239,409 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if self.GetWhileContext(): return self.GetWhileContext().back_prop return False + + +def _compile_internal(computation, inputs=None): + """Builds graph operators that compiles and symbolically executes computation. + + Args: + computation: A Python function that builds the computation to compile and + execute. + inputs: A list of input tensors or `None` (equivalent to `[]`). Its order + should match ordering of computation arguments. + Returns: + A list of output tensors from computation. + Raises: + ValueError: If any element in computation outputs is neither an operations + or a value that can be converted to tensor. + TypeError: If `inputs` is not a list or tuple. + """ + if inputs is None: + inputs = [] + + if not isinstance(inputs, collections.Sequence): + raise TypeError('inputs must be a list') + + # Converts inputs to Tensors. + inputs = [ops.convert_to_tensor(x) for x in inputs] + input_arity = len(inputs) + + arg_error = tpu_function.check_function_argument_count( + computation, input_arity, infeed_queue=None) + if arg_error is not None: + raise TypeError( + 'Supplied computation cannot be called with the specified inputs. You ' + 'specified %d inputs: %s, but the computation needs %s' % + (input_arity, str([i.name for i in inputs[0]]), arg_error)) + + cluster_name = ops.get_default_graph().unique_name('cluster') + pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') + context = XLACompileContext(name=cluster_name, pivot=pivot) + try: + context.Enter() + + # Add identity ops so even unused inputs are 'consumed' by the + # computation. + computation_inputs = [ + array_ops.identity(x, name='input_{}'.format(i)) + for i, x in enumerate(inputs) + ] + + # Only resource variables work inside an XLA computation, so turn on + # resource variables for the computation. + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + vscope.set_use_resource(True) + + with _disable_summary_context(): + outputs = computation(*computation_inputs) + + # Restore variable scope after computation. + vscope.set_use_resource(saved_use_resource) + + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that return value of this function always contains + # at least one op that can trigger XlaLaunch node. + outputs += (control_flow_ops.no_op(),) + try: + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be Operations' + ' or convertible to Tensors. Got error: "%s"' % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + 'XLA computation function must return zero or more Tensor values ' + 'followed by zero or more Operations.') + output_arity = len(output_tensors) + + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else ''): + new_output_tensors.append(array_ops.identity(t)) + + output_tensors = new_output_tensors + context.ExitResult(output_tensors) + finally: + context.report_unsupported_operations() + context.Exit() + + outputs = [ + xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i)) + for i in xrange(output_arity) + ] + + with ops.control_dependencies(output_operations): + if output_arity == 0: + # When XLA computation returns only operations and no tensors, a NoOp + # dependent on the operations in outputs is returned. Otherwise final + # outputs would be empty and there is no way to trigger returned + # operations. + return control_flow_ops.no_op(name='output_0') + else: + # Wraps the outputs in identity operators that carries control + # dependencies. + return [ + array_ops.identity(outputs[i], name='output_%d' % i) + for i in xrange(output_arity) + ] + + +@contextlib.contextmanager +def _disable_summary_context(): + """Enters a context where all summary ops are skipped. + + Summaries are not yet supported in xla.compile(). So we provide this context + manager that can skip creating summary ops. This is a temporary workaround due + to XLA not supporting summary ops. + + Yields: + None. + """ + original_skip_summary_func = summary_op_util.skip_summary + summary_op_util.skip_summary = lambda: True + + try: + yield + finally: + summary_op_util.skip_summary = original_skip_summary_func + + +class _CapturedObject(object): + """A placeholder to capture an object.""" + + def __init__(self): + self._object = None + + def capture(self, o): + if self._object: + raise RuntimeError( + 'InternalError: _CapturedObject can capture only once. Please file ' + 'bug.') + + self._object = o + + def get(self): + return self._object + + +def _get_scaffold(captured_scaffold_fn): + """Retrieves the Scaffold from `captured_scaffold_fn`.""" + scaffold_fn = captured_scaffold_fn.get() + + if not scaffold_fn: + return None + + scaffold = scaffold_fn() + if scaffold is None: + raise ValueError( + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') + + return scaffold + + +class _ModelFnWrapper(object): + """_ModelFnWrapper supports executing model_fn with XLA.""" + + def __init__(self, function): + self._model_fn = function + + def __call__(self, features, labels, mode, params): + + # TPUEstimator compiles model_fn when use_tpu=True. To avoid double + # compilation, we use this params['use_tpu'] as a hint. When it is set to + # True, model_fn is called without compilation. + # Note that this condition isn't accurate for the case of exporting a model. + # In that case we should ideally not compile so that user can see detailed + # graph. However, we don't have enough information to tell whether model_fn + # is being called for export mode or not. + # TODO(ycao): Make this condition more accurate when implementing PREDICT + # mode. + if params.get('use_tpu'): + return self._call_model_fn(features, labels, mode, params) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_step, captured_scaffold_fn = self._make_train_step( + features, labels, params) + (loss,) = compile(train_step) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + train_op=array_ops.identity(loss), + scaffold=_get_scaffold(captured_scaffold_fn)) + elif mode == model_fn_lib.ModeKeys.EVAL: + eval_step, captured_eval_metric_fn, captured_scaffold_fn = ( + self._make_eval_step(features, labels, params)) + outputs = compile(eval_step) + loss = outputs[0] + + # Calculate eval_metric_ops if eval_metric_fn is set and captured. + eval_metric_fn = captured_eval_metric_fn.get() + if eval_metric_fn: + eval_metric_fn_tensors = outputs[1:] + eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors) + else: + eval_metric_ops = None + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=eval_metric_ops, + scaffold=_get_scaffold(captured_scaffold_fn)) + else: + raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are' + ' supported' % mode) + + def _make_train_step(self, features, labels, params): + """Creates a single step of training for xla.compile().""" + captured_scaffold_fn = _CapturedObject() + + def train_step(): + """A single step of training.""" + estimator_spec = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, params) + + try: + captured_scaffold_fn.capture(estimator_spec.scaffold_fn) + except AttributeError: + captured_scaffold_fn.capture(None) + + # train_step will be run by xla.compile(). xla.compile() only supports + # tensor output while train_op can be either an operation or a tensor. + # Even though xla.compile() automatically adds operation-typed train_op as + # control dependency of other tensor outputs, it doesn't do so for + # tensor-typed train_op. Thus, we need to set it explicitly here. + with ops.control_dependencies([estimator_spec.train_op]): + return array_ops.identity(estimator_spec.loss) + + return train_step, captured_scaffold_fn + + def _make_eval_step(self, features, labels, params): + """Creates a single step of evaluation for xla.compile().""" + captured_eval_metric_fn = _CapturedObject() + captured_scaffold_fn = _CapturedObject() + + def eval_step(): + """A single step of evaluation.""" + estimator_spec = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.EVAL, params) + + try: + captured_scaffold_fn.capture(estimator_spec.scaffold_fn) + except AttributeError: + captured_scaffold_fn.capture(None) + + eval_metric_fn = None + eval_metric_fn_tensors = [] + try: + if estimator_spec.eval_metrics: + (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics + except AttributeError: + pass + + # If a dictionary is provided, we need to convert it into a list sorted + # according to order of eval_metric_fn positional arguments. + if isinstance(eval_metric_fn_tensors, dict): + eval_metric_fn_args = function_utils.fn_args(eval_metric_fn) + eval_metric_fn_tensors = [ + eval_metric_fn_tensors[i] for i in eval_metric_fn_args + ] + + captured_eval_metric_fn.capture(eval_metric_fn) + + return tuple([estimator_spec.loss] + eval_metric_fn_tensors) + + return eval_step, captured_eval_metric_fn, captured_scaffold_fn + + def _call_model_fn(self, features, labels, mode, params): + """Calls the model_fn with required parameters.""" + model_fn_args = function_utils.fn_args(self._model_fn) + kwargs = {} + + if 'labels' in model_fn_args: + kwargs['labels'] = labels + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') + if 'mode' in model_fn_args: + kwargs['mode'] = mode + + if 'params' in model_fn_args: + kwargs['params'] = params + + return self._verify_estimator_spec( + self._model_fn(features=features, **kwargs)) + + def _verify_estimator_spec(self, estimator_spec): + """Verifies estimator spec contains correct data.""" + # TODO(ycao): Implement estimator spec verification for other modes. + + try: + if estimator_spec.scaffold: + logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation' + '. Please use TPUEstimatorSpec.scaffold_fn instead.') + except AttributeError: + pass + + try: + if estimator_spec.eval_metric_ops: + raise ValueError('EstimatorSpec.eval_metric_ops is not supported with ' + 'XLA compilation. Please use ' + 'TPUEstimatorSpec.eval_metrics instead.') + except AttributeError: + pass + + if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL: + # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics, + # check that eval_metrics contains eval_metric_fn and + # eval_metric_fn_tensors with matching arguments. + try: + eval_metrics = estimator_spec.eval_metrics + except AttributeError: + eval_metrics = None + + if eval_metrics: + (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics + eval_metric_fn_args = function_utils.fn_args(eval_metric_fn) + + if isinstance(eval_metric_fn_tensors, dict): + missing_tensors = [ + i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors + ] + additional_tensors = [ + i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args + ] + + if missing_tensors: + raise ValueError('Arguments %s are needed by metric_fn (first ' + 'element of TPUEstimatorSpec.eval_metrics) but ' + 'they are not provided by evaluation tensors ' + '(second element of TPUEstimatorSpec.eval_metrics)' + '.' % missing_tensors) + + if additional_tensors: + raise ValueError('Arguments %s are provided by evaluation tensors ' + '(second element of TPUEstimatorSpec.eval_metrics)' + ' but they are not needed by metric_fn (first ' + 'element of TPUEstimatorSpec.eval_metrics).' % + additional_tensors) + + return estimator_spec + + +def estimator_model_fn(target_model_fn=None): + """estimator_model_fn decorates a model_fn to be compiled for execution. + + Currently only it only works with `TPUEstimator`. If you need to use it with + base `Estimator`, please add `tf.enable_resource_variables()` at beginning of + your program. + + Example 1, decorating model_fn: + ``` + @xla.estimator_model_fn() + def model_fn(features, labels, mode, params): + ... + return EstimatorSpec(...) + + + est = Estimator(model_fn=model_fn, ...) + est.train(...) + + ``` + + Example 2, decorator as function: + ``` + def model_fn(features, labels, mode, params): + ... + return EstimatorSpec(...) + + est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...) + est.train(...) + ``` + + Args: + target_model_fn: model_fn to be decorated. This is only needed when + decorator is used in function call form (example 2). + + Returns: + Decorated target_model_fn. + """ + + def decorated(function): + return tf_decorator.make_decorator(function, _ModelFnWrapper(function)) + + return decorated(target_model_fn) if target_model_fn else decorated diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py index d1af15f7e423c5135071ea73f6b7a0709d140600..67f8ac2b9322f39b02c521f8b9cde3831c7889b8 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -102,9 +102,9 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): 0.0, (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive))) - multipliers += scale * inactive + multipliers = multipliers + (scale * inactive) new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype) - multipliers *= new_inactive + multipliers = multipliers * new_inactive return (iteration, multipliers, new_inactive, inactive) iteration = standard_ops.constant(0) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 2c673d9347141b3a12eb9ec76065d22f1769ac12..a6cb1f62f059770c90bd1aeea391d841aed9aacf 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -175,9 +175,9 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): scale = (1.0 - standard_ops.reduce_sum( matrix, axis=0, keepdims=True)) / standard_ops.maximum( 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) - matrix += scale * inactive + matrix = matrix + (scale * inactive) new_inactive = standard_ops.cast(matrix > 0, matrix.dtype) - matrix *= new_inactive + matrix = matrix * new_inactive return (iteration, matrix, new_inactive, inactive) iteration = standard_ops.constant(0) @@ -210,8 +210,9 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): # For numerical reasons, make sure that the largest matrix element is zero # before exponentiating. - log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True) - log_matrix -= standard_ops.log( + log_matrix = log_matrix - standard_ops.reduce_max( + log_matrix, axis=0, keepdims=True) + log_matrix = log_matrix - standard_ops.log( standard_ops.reduce_sum( standard_ops.exp(log_matrix), axis=0, keepdims=True)) return log_matrix diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 6c9ab6aeb87fd39b22ab4f28d69b432b15899a13..9c5871da343ab1a5f11aeb674a20ab83f2eb1fbf 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -31,7 +31,7 @@ from __future__ import division from __future__ import print_function from copy import deepcopy -from tensorflow.python.ops.variables import Variable +from tensorflow.python.ops.variables import VariableV1 from tensorflow.python.client.session import Session from tensorflow.python.framework import ops @@ -55,7 +55,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''): TypeError: If `org_instance` is not a `Variable`. """ - if not isinstance(org_instance, Variable): + if not isinstance(org_instance, VariableV1): raise TypeError(str(org_instance) + ' is not a Variable') #The name of the new variable @@ -88,7 +88,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''): #Initialize the new variable with to_graph.as_default(): - new_var = Variable( + new_var = VariableV1( init_value, trainable, name=new_name, diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index 05744bec4e05405c04b5ec442e72e4495737ab5b..4d8651a79fde9b876d4fdd9b050e71d2eb7c893d 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -26,30 +26,33 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -graph1 = ops.Graph() -graph2 = ops.Graph() - class CopyVariablesTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testVariableCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Define a Variable in graph1 - some_var = variables.Variable(2) + some_var = variables.VariableV1(2) #Initialize session sess1 = session_lib.Session() #Initialize the Variable variables.global_variables_initializer().run(session=sess1) #Make a copy of some_var in the defsult scope in graph2 - copy1 = copy_elements.copy_variable_to_graph(some_var, graph2) + copy1 = copy_elements.copy_variable_to_graph(some_var, self.graph2) #Make another copy with different scope - copy2 = copy_elements.copy_variable_to_graph(some_var, graph2, "test_scope") + copy2 = copy_elements.copy_variable_to_graph(some_var, + self.graph2, + "test_scope") #Initialize both the copies - with graph2.as_default(): + with self.graph2.as_default(): #Initialize Session sess2 = session_lib.Session() #Initialize the Variables @@ -67,12 +70,16 @@ class CopyVariablesTest(test.TestCase): class CopyOpsTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testOpsCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Initialize a basic expression y = ax + b x = array_ops.placeholder("float") - a = variables.Variable(3.0) + a = variables.VariableV1(3.0) b = constant_op.constant(4.0) ax = math_ops.multiply(x, a) y = math_ops.add(ax, b) @@ -82,21 +89,21 @@ class CopyOpsTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #First, initialize a as a Variable in graph2 - a1 = copy_elements.copy_variable_to_graph(a, graph2) + a1 = copy_elements.copy_variable_to_graph(a, self.graph2) #Initialize a1 in graph2 - with graph2.as_default(): + with self.graph2.as_default(): #Initialize session sess2 = session_lib.Session() #Initialize the Variable variables.global_variables_initializer().run(session=sess2) #Initialize a copy of y in graph2 - y1 = copy_elements.copy_op_to_graph(y, graph2, [a1]) + y1 = copy_elements.copy_op_to_graph(y, self.graph2, [a1]) #Now that y has been copied, x must be copied too. #Get that instance - x1 = copy_elements.get_copied_op(x, graph2) + x1 = copy_elements.get_copied_op(x, self.graph2) #Compare values of y & y1 for a sample input #and check if they match diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 2a91dcb63a80016e62d10d1310ca57e3e54434c5..43bb43129bfe1cb1c66f4965476f9b7f849658ad 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): log_norm) return log_norm - max_seq_len = array_ops.shape(inputs)[1] - return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), - true_fn=_single_seq_fn, - false_fn=_multi_seq_fn) + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or + array_ops.shape(inputs)[1], 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 5a667485beebe4bee7f051b5920920c72134987f..c59d3682d404e032d9f4bf81ef54ab456341cefa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -413,6 +413,31 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): self._testOneLSTMParamsSize(num_layers, num_units, input_size, direction) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLSTMParamsSizeShape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + constant_op.constant([4]), 200, 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, constant_op.constant([200]), 200, + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + model = _CreateModel( + cudnn_rnn_ops.CUDNN_LSTM, + 4, 200, constant_op.constant([200]), + direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + params_size = model.params_size() + class CudnnRNNTestInference(TensorFlowTestCase): diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index fda1b9f1b36eaad69377fb33df7e15a4e87b32b8..57793a8ff5e2ec49dfea42c08eb9456cb2875eab 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -460,7 +460,7 @@ class CudnnRNNTestBasic(test_util.TensorFlowTestCase): grad, = gradients.gradients( math_ops.reduce_sum(accumulation), (original_input,)) init_op = variables.global_variables_initializer() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) accumulation_eval, grad_eval = sess.run((accumulation, grad)) self.assertAllEqual([28, 100, 100], accumulation_eval.shape) diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 9f710613dd0d549d4f93bae8780427f7878234a6..38f1c65a4d5c33ab2558fa9277b512ab86e98959 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load( - "//tensorflow:tensorflow.bzl", - "tf_custom_op_library", - "tf_gen_op_libs", - "if_not_windows", -) -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "if_static", -) - py_library( name = "data", srcs = ["__init__.py"], @@ -25,30 +14,3 @@ py_library( "//tensorflow/python:util", ], ) - -cc_library( - name = "lib_proto_parsing_for_dataset_ops", - deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), -) - -tf_custom_op_library( - name = "_dataset_ops.so", - srcs = [ - "ops/dataset_ops.cc", - "ops/indexed_dataset_ops.cc", - ], - deps = [ - "//tensorflow/contrib/data/kernels:dataset_kernels", - "//tensorflow/contrib/data/kernels:indexed_dataset", - ] + if_static( - extra_deps = [":lib_proto_parsing_for_dataset_ops"], - otherwise = [], - ), -) - -tf_gen_op_libs( - op_lib_names = [ - "dataset_ops", - "indexed_dataset_ops", - ], -) diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 848782e8d89b8670caf3b45de4912a7e0855c102..90be7a66cac6746e29a121fe6a772a94e04e8f02 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -1,10 +1,12 @@ `tf.contrib.data` API ===================== -NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead. -We are continuing to support existing code using the `tf.contrib.data` APIs in -the current version of TensorFlow, but will eventually remove support. The -`tf.data` APIs are subject to backwards compatibility guarantees. +NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead, +or `tf.data.experimental` for the experimental transformations previously hosted +in this module. We are continuing to support existing code using the +`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually +remove support. The non-experimental `tf.data` APIs are subject to backwards +compatibility guarantees. Porting your code to `tf.data` ------------------------------ @@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of changes is as follows: * `dataset.dense_to_sparse_batch(...)` is now - `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`. + `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`. * `dataset.enumerate(...)` is now - `dataset.apply(tf.contrib.data.enumerate_dataset(...))`. + `dataset.apply(tf.data.experimental.enumerate_dataset(...))`. * `dataset.group_by_window(...)` is now - `dataset.apply(tf.contrib.data.group_by_window(...))`. + `dataset.apply(tf.data.experimental.group_by_window(...))`. * `dataset.ignore_errors()` is now - `dataset.apply(tf.contrib.data.ignore_errors())`. + `dataset.apply(tf.data.experimental.ignore_errors())`. * `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`. The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index baec238c62e5cd375e3e8d46039e8e5b21269a6f..c3d3e981fa10144ed94217cf948b485a7c2bced8 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -44,6 +44,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@group_by_reducer @@group_by_window @@ignore_errors +@@latency_stats @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator @@ -57,11 +58,15 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@reduce_dataset @@sample_from_datasets @@scan +@@set_stats_aggregator @@shuffle_and_repeat @@sliding_window_batch @@sloppy_interleave +@@StatsAggregator @@unbatch @@unique + +@@AUTOTUNE """ from __future__ import absolute_import @@ -107,12 +112,13 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch from tensorflow.contrib.data.python.ops.unique import unique from tensorflow.contrib.data.python.ops.writers import TFRecordWriter + +# Optimization constant that can be used to enable auto-tuning. +from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE + from tensorflow.python.data.ops.iterator_ops import get_next_as_optional from tensorflow.python.data.ops.optional_ops import Optional # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) - -# A constant that can be used to enable auto-tuning. -AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD deleted file mode 100644 index ec6cb37193cdfbc888df5dc6787854241daea621..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/kernels/BUILD +++ /dev/null @@ -1,139 +0,0 @@ -# Description: -# Contains kernels for datasets and iterators. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "indexed_dataset_headers", - hdrs = ["indexed_dataset.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], -) - -cc_library( - name = "indexed_dataset", - srcs = [ - "identity_indexed_dataset.cc", - "indexed_dataset.cc", - ], - deps = [ - ":indexed_dataset_headers", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "prefetching_kernels", - srcs = ["prefetching_kernels.cc"], - deps = [ - "//tensorflow/core:core_cpu_headers_lib", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "directed_interleave_dataset_op", - srcs = ["directed_interleave_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "csv_dataset_op", - srcs = ["csv_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "ignore_errors_dataset_op", - srcs = ["ignore_errors_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "lmdb_dataset_op", - srcs = ["lmdb_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@lmdb", - "@protobuf_archive//:protobuf_headers", - ], -) - -cc_library( - name = "threadpool_dataset_op", - srcs = ["threadpool_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "unique_dataset_op", - srcs = ["unique_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "assert_next_dataset_op", - srcs = ["assert_next_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "dataset_kernels", - deps = [ - ":assert_next_dataset_op", - ":csv_dataset_op", - ":directed_interleave_dataset_op", - ":ignore_errors_dataset_op", - ":indexed_dataset", - ":lmdb_dataset_op", - ":prefetching_kernels", - ":threadpool_dataset_op", - ":unique_dataset_op", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], -) diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc deleted file mode 100644 index ae104d55bd813fdbc9829ccbc274612a112c8e1d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ /dev/null @@ -1,278 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("DirectedInterleaveDataset") - .Input("selector_input_dataset: variant") - .Input("data_input_datasets: N * variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Attr("N: int >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -A substitute for `InterleaveDataset` on a fixed list of `N` datasets. - -selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines - which of the `N` data inputs should produce the next output element. -data_input_datasets: `N` datasets with the same type that will be interleaved - according to the values of `selector_input_dataset`. -)doc"); - -REGISTER_OP("CSVDataset") - .Input("filenames: string") - .Input("compression_type: string") - .Input("buffer_size: int64") - .Input("header: bool") - .Input("field_delim: string") - .Input("use_quote_delim: bool") - .Input("na_value: string") - .Input("select_cols: int64") - .Input("record_defaults: output_types") - .Output("handle: variant") - .Attr("output_types: list({float,double,int32,int64,string}) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // `filenames` must be a scalar or a vector. - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); - // `compression_type`, `buffer_size`, `header`, `field_delim`, - // `use_quote_delim`, `na_value` must be scalars - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - // `select_cols` must be a vector - TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); - // `record_defaults` must be lists of scalars - for (size_t i = 8; i < c->num_inputs(); ++i) { - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); - } - return shape_inference::ScalarShape(c); - }); - -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. -)doc"); - -REGISTER_OP("UniqueDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the unique elements of `input_dataset`. -)doc"); - -REGISTER_OP("IteratorGetDevice") - .Input("resource: resource") - .Output("device: string") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Returns the name of the device on which `resource` has been placed. -)doc"); - -REGISTER_OP("FunctionBufferingResource") - .Input("string_arg: string") - .Input("target_device: string") - .Output("resource: resource") - .Attr("shared_name: string") - .Attr("container: string") - .Attr("f: func") - .Attr("buffer_size: int") - .Attr("output_types: list(type)") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Creates a resource that fills up a buffer by making function calls. - -string_arg: String argument to the function call. -target_device: Target device to execute the function on. -resource: Handle to the resource created. -f: Function to be executed. -buffer_size: Size of the buffer. -container: If non-empty, this resource is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this resource will be shared under the given name - across multiple sessions. -output_types: The type list for the return values. -)doc"); - -REGISTER_OP("FunctionBufferingResourceGetNext") - .Input("function_buffer_resource: resource") - .Attr("output_types: list(type)") - .Output("output: output_types") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Gets the next element from a FunctionBufferingResource. - -function_buffer_resource: The FunctionBufferingResource handle. -output: A list of return values. -output_types: The type list for the return values. -)doc"); - -REGISTER_OP("FunctionBufferingResourceReset") - .Input("function_buffer_resource: resource") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Resets the FunctionBufferingResource. - -function_buffer_resource: The FunctionBufferingResource handle. -)doc"); - -REGISTER_OP("MultiDeviceIterator") - .Output("handle: resource") - .Attr("devices: list(string) >= 1") - .Attr("shared_name: string") - .Attr("container: string") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Doc(R"doc( -Creates a MultiDeviceIterator resource. - -handle: Handle to the resource created. -devices: A list of devices the iterator works across. -shared_name: If non-empty, this resource will be shared under the given name - across multiple sessions. -container: If non-empty, this resource is placed in the given container. - Otherwise, a default container is used. -output_types: The type list for the return values. -output_shapes: The list of shapes being produced. -)doc"); - -REGISTER_OP("MultiDeviceIteratorInit") - .Input("dataset: variant") - .Input("multi_device_iterator: resource") - .Input("max_buffer_size: int64") - .Output("incarnation_id: int64") - .Doc(R"doc( -Initializes the multi device iterator with the given dataset. -max_buffer_size: The maximum size of the host side per device buffer to keep. -incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator - is running. -dataset: Dataset to be iterated upon. -multi_device_iterator: A MultiDeviceIteratorResource. -)doc"); - -REGISTER_OP("MultiDeviceIteratorGetNextFromShard") - .Input("multi_device_iterator: resource") - .Input("shard_num: int32") - .Input("incarnation_id: int64") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Doc(R"doc( -Gets next element for the provided shard number. - -multi_device_iterator: A MultiDeviceIterator resource. -shard_num: Integer representing which shard to fetch data for. -incarnation_id: Which incarnation of the MultiDeviceIterator is running. -components: Result of the get_next on the dataset. -output_types: The type list for the return values. -output_shapes: The list of shapes being produced. -)doc"); - -REGISTER_OP("MultiDeviceIteratorToStringHandle") - .Input("multi_device_iterator: resource") - .Output("string_handle: string") - .Doc(R"doc( -Produces a string handle for the given MultiDeviceIterator. - -multi_device_iterator: A MultiDeviceIterator resource. -string_handle: A string representing the resource. -)doc"); - -REGISTER_OP("MultiDeviceIteratorFromStringHandle") - .Input("string_handle: string") - .Output("multi_device_iterator: resource") - .Attr("output_types: list(type) >= 0 = []") - .Attr("output_shapes: list(shape) >= 0 = []") - .Doc(R"doc( -Generates a MultiDeviceIterator resource from its provided string handle. - -string_handle: String representing the resource. -multi_device_iterator: A MultiDeviceIterator resource. -output_types: The type list for the return values. -output_shapes: The list of shapes being produced. -)doc"); - -REGISTER_OP("ThreadPoolDataset") - .Input("input_dataset: variant") - .Input("thread_pool: resource") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that uses a custom thread pool to compute `input_dataset`. - -handle: A resource produced by the ThreadPoolHandle op. -)doc"); - -REGISTER_OP("ThreadPoolHandle") - .Output("handle: resource") - .SetShapeFn(shape_inference::ScalarShape) - .Attr("num_threads: int") - .Attr("max_intra_op_parallelism: int = 1") - .Attr("display_name: string") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Doc(R"doc( -Creates a custom thread pool with the given number of threads. - -handle: A resource that can be consumed by one or more ThreadPoolDataset ops. -num_threads: The number of threads in the thread pool. -max_intra_op_parallelism: The maximum degree of parallelism to use within - operations that execute on this threadpool. -display_name: A human-readable name for the threads that may be visible in - some visualizations. -)doc"); - -REGISTER_OP("AssertNextDataset") - .Input("input_dataset: variant") - .Input("transformations: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // transformations should be a vector. - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); - return shape_inference::ScalarShape(c); - }); - -REGISTER_OP("LMDBDataset") - .Input("filenames: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc deleted file mode 100644 index cd9b7c68a04a33ca6dec1e9088c3606deebdb7f4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("IdentityIndexedDataset") - .Input("size: uint64") - .Output("handle: variant") - .SetIsStateful() - .SetShapeFn( - shape_inference::ScalarShape); // TODO(saeta): check input shapes. - -/////////////////////////////////////////////////////////////////////////////// -// IndexedDataset Internals -/////////////////////////////////////////////////////////////////////////////// - -// Creates the handle. -REGISTER_OP("MaterializedIndexDatasetHandle") - .Output("handle: resource") - .Attr("container: string") - .Attr("shared_name: string") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); - -// Actually materialize the materialize handle. -REGISTER_OP("IndexedDatasetMaterialize") - .Input("dataset: variant") - .Input("materialized: resource") - .SetShapeFn(shape_inference::NoOutputs); - -namespace { - -Status GetShapeFn(shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); -} - -} // namespace - -REGISTER_OP("IndexedDatasetGet") - .Input("materialized: resource") - .Input("index: uint64") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(GetShapeFn) - .Doc(R"doc( -Gets the element at `index` from `materialized` IndexedDataset. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1f947e97f9231f83d08b5f6b8c5765f82c6db6b3..42f538b4ba1cb5b6a9a717e94f4e688cae56b056 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -8,193 +8,25 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") py_test( - name = "batch_dataset_op_test", - size = "medium", - srcs = ["batch_dataset_op_test.py"], + name = "assert_element_shape_test", + srcs = ["assert_element_shape_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_oss", # (b/79552534) - "no_pip", - ], deps = [ "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "bucketing_test", - size = "medium", - srcs = ["bucketing_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) -py_test( - name = "csv_dataset_op_test", - size = "medium", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:readers", - "//third_party/py/numpy", - ], -) - -py_test( - name = "dataset_constructor_op_test", - size = "medium", - srcs = ["dataset_constructor_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "nomac", # b/62040583 - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -py_test( - name = "directed_interleave_dataset_test", - size = "medium", - srcs = ["directed_interleave_dataset_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:random_seed", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "get_single_element_test", - size = "small", - srcs = ["get_single_element_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/ops:dataset_ops", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "indexed_dataset_ops_test", - srcs = ["indexed_dataset_ops_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:contrib_op_loader", - "//tensorflow/contrib/data/python/ops:gen_dataset_ops", - "//tensorflow/contrib/data/python/ops:indexed_dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "interleave_dataset_op_test", - size = "medium", - srcs = ["interleave_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", - "no_pip", - "notap", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/ops:dataset_ops", - "@six_archive//:six", - ], -) - -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:estimator_py", - ], -) - py_test( name = "lmdb_dataset_op_test", size = "medium", @@ -216,252 +48,24 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//third_party/py/numpy", ], ) py_test( - name = "map_dataset_op_test", - size = "medium", - srcs = ["map_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "noasan", # times out - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "filter_dataset_op_test", - size = "medium", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "map_defun_op_test", + name = "reduce_dataset_test", size = "small", - srcs = ["map_defun_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["reduce_dataset_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:map_defun", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:session", - ], -) - -py_test( - name = "parsing_ops_test", - size = "small", - srcs = ["parsing_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:parsing_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/compat:compat", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], - tags = [ - "manual", - "no_oss", - "no_windows_gpu", - "notap", - ], -) - -py_test( - name = "range_dataset_op_test", - size = "small", - srcs = ["range_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:counter", - "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_library( - name = "reader_dataset_ops_test_base", - testonly = 1, - srcs = [ - "reader_dataset_ops_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = [ - "//tensorflow/contrib/data/python/kernel_tests:__pkg__", - "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -py_test( - name = "reader_dataset_ops_test", - size = "medium", - srcs = ["reader_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":reader_dataset_ops_test_base", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", - ], -) - -py_test( - name = "resample_test", - size = "medium", - srcs = ["resample_test.py"], - shard_count = 2, - srcs_version = "PY2AND3", - tags = [ - "noasan", - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:resampling", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", - "@six_archive//:six", - ], -) - -py_test( - name = "scan_dataset_op_test", - size = "small", - srcs = ["scan_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:scan_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - "//third_party/py/numpy", - ], -) - -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:shuffle_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) @@ -477,151 +81,9 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) - -py_library( - name = "sql_dataset_op_test_base", - srcs = ["sql_dataset_op_test_base.py"], - srcs_version = "PY2AND3", - visibility = [ - "//tensorflow/contrib/data/python/kernel_tests:__pkg__", - "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "@org_sqlite//:python", - ], -) - -py_test( - name = "sql_dataset_op_test", - size = "small", - srcs = ["sql_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sql_dataset_op_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - ], -) - -py_test( - name = "stats_dataset_ops_test", - size = "medium", - srcs = ["stats_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":reader_dataset_ops_test_base", - ":stats_dataset_test_base", - "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "stats_dataset_test_base", - srcs = ["stats_dataset_test_base.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "threadpool_dataset_ops_test", - size = "small", - srcs = ["threadpool_dataset_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:threadpool", - "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:script_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "unique_dataset_op_test", - size = "small", - srcs = ["unique_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_test( - name = "window_dataset_op_test", - size = "medium", - srcs = ["window_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:grouping", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "writer_ops_test", - size = "small", - srcs = ["writer_ops_test.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:writers", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -py_library( - name = "test_utils", - srcs = ["test_utils.py"], - deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/util:nest", - ], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0456463a1928cf226010670b90a5d574579e0411 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -0,0 +1,226 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class AssertElementShapeTest(test_base.DatasetTestBase): + + def test_assert_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(expected_shapes, dataset.output_shapes) + + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def test_assert_partial_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + partial_expected_shape = ( + tensor_shape.TensorShape(None), # Unknown shape + tensor_shape.TensorShape((None, 4))) # Partial shape + result = dataset.apply( + batching.assert_element_shape(partial_expected_shape)) + # Partial shapes are merged with actual shapes: + actual_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(actual_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_partial_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.cached_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py deleted file mode 100644 index 8e368bf2bc5060e1655dd24b1d285b0ee80e094d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ /dev/null @@ -1,990 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import time - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.python.client import session -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test -from tensorflow.python.util import compat - - -class BatchDatasetTest(test.TestCase, parameterized.TestCase): - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testDenseToSparseBatchDataset(self): - components = np.random.randint(12, size=(100,)).astype(np.int32) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - for start in range(0, len(components), 4): - results = sess.run(get_next) - self.assertAllEqual([[i, j] - for i, c in enumerate(components[start:start + 4]) - for j in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start + 4] for _ in range(c)], - results.values) - self.assertAllEqual([min(4, - len(components) - start), 12], - results.dense_shape) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testDenseToSparseBatchDatasetWithUnknownShape(self): - components = np.random.randint(5, size=(40,)).astype(np.int32) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).apply( - batching.dense_to_sparse_batch( - 4, [5, None])).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - for start in range(0, len(components), 4): - results = sess.run(get_next) - self.assertAllEqual([[i, j, z] - for i, c in enumerate(components[start:start + 4]) - for j in range(c) - for z in range(c)], results.indices) - self.assertAllEqual([ - c - for c in components[start:start + 4] for _ in range(c) - for _ in range(c) - ], results.values) - self.assertAllEqual([ - min(4, - len(components) - start), 5, - np.max(components[start:start + 4]) - ], results.dense_shape) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testDenseToSparseBatchDatasetWithInvalidShape(self): - input_tensor = array_ops.constant([[1]]) - with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): - dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator() - - def testDenseToSparseBatchDatasetShapeErrors(self): - input_tensor = array_ops.placeholder(dtypes.int32) - iterator = ( - dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - # Initialize with an input tensor of incompatible rank. - sess.run(init_op, feed_dict={input_tensor: [[1]]}) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "incompatible with the row shape"): - sess.run(get_next) - - # Initialize with an input tensor that is larger than `row_shape`. - sess.run(init_op, feed_dict={input_tensor: range(13)}) - with self.assertRaisesRegexp(errors.DataLossError, - "larger than the row shape"): - sess.run(get_next) - - def testUnbatchScalarDataset(self): - data = tuple([math_ops.range(10) for _ in range(3)]) - data = dataset_ops.Dataset.from_tensor_slices(data) - expected_types = (dtypes.int32,) * 3 - data = data.batch(2) - self.assertEqual(expected_types, data.output_types) - data = data.apply(batching.unbatch()) - self.assertEqual(expected_types, data.output_types) - - iterator = data.make_one_shot_iterator() - op = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual((i,) * 3, sess.run(op)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(op) - - def testUnbatchDatasetWithStrings(self): - data = tuple([math_ops.range(10) for _ in range(3)]) - data = dataset_ops.Dataset.from_tensor_slices(data) - data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z)) - expected_types = (dtypes.int32, dtypes.string, dtypes.int32) - data = data.batch(2) - self.assertEqual(expected_types, data.output_types) - data = data.apply(batching.unbatch()) - self.assertEqual(expected_types, data.output_types) - - iterator = data.make_one_shot_iterator() - op = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(op) - - def testUnbatchDatasetWithSparseTensor(self): - st = sparse_tensor.SparseTensorValue( - indices=[[i, i] for i in range(10)], - values=list(range(10)), - dense_shape=[10, 10]) - data = dataset_ops.Dataset.from_tensors(st) - data = data.apply(batching.unbatch()) - data = data.batch(5) - data = data.apply(batching.unbatch()) - iterator = data.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - st_row = sess.run(next_element) - self.assertEqual([i], st_row.indices) - self.assertEqual([i], st_row.values) - self.assertEqual([10], st_row.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testUnbatchDatasetWithDenseAndSparseTensor(self): - st = sparse_tensor.SparseTensorValue( - indices=[[i, i] for i in range(10)], - values=list(range(10)), - dense_shape=[10, 10]) - data = dataset_ops.Dataset.from_tensors((list(range(10)), st)) - data = data.apply(batching.unbatch()) - data = data.batch(5) - data = data.apply(batching.unbatch()) - iterator = data.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - dense_elem, st_row = sess.run(next_element) - self.assertEqual(i, dense_elem) - self.assertEqual([i], st_row.indices) - self.assertEqual([i], st_row.values) - self.assertEqual([10], st_row.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testUnbatchSingleElementTupleDataset(self): - data = tuple([(math_ops.range(10),) for _ in range(3)]) - data = dataset_ops.Dataset.from_tensor_slices(data) - expected_types = ((dtypes.int32,),) * 3 - data = data.batch(2) - self.assertEqual(expected_types, data.output_types) - data = data.apply(batching.unbatch()) - self.assertEqual(expected_types, data.output_types) - - iterator = data.make_one_shot_iterator() - op = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(((i,),) * 3, sess.run(op)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(op) - - def testUnbatchMultiElementTupleDataset(self): - data = tuple([(math_ops.range(10 * i, 10 * i + 10), - array_ops.fill([10], "hi")) for i in range(3)]) - data = dataset_ops.Dataset.from_tensor_slices(data) - expected_types = ((dtypes.int32, dtypes.string),) * 3 - data = data.batch(2) - self.assertAllEqual(expected_types, data.output_types) - data = data.apply(batching.unbatch()) - self.assertAllEqual(expected_types, data.output_types) - - iterator = data.make_one_shot_iterator() - op = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), - sess.run(op)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(op) - - def testUnbatchEmpty(self): - data = dataset_ops.Dataset.from_tensors( - (constant_op.constant([]), constant_op.constant([], shape=[0, 4]), - constant_op.constant([], shape=[0, 4, 0]))) - data = data.apply(batching.unbatch()) - iterator = data.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testUnbatchStaticShapeMismatch(self): - data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8), - np.arange(9))) - with self.assertRaises(ValueError): - data.apply(batching.unbatch()) - - def testUnbatchDynamicShapeMismatch(self): - ph1 = array_ops.placeholder(dtypes.int32, shape=[None]) - ph2 = array_ops.placeholder(dtypes.int32, shape=None) - data = dataset_ops.Dataset.from_tensors((ph1, ph2)) - data = data.apply(batching.unbatch()) - iterator = data.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - # Mismatch in the 0th dimension. - sess.run( - iterator.initializer, - feed_dict={ - ph1: np.arange(7).astype(np.int32), - ph2: np.arange(8).astype(np.int32) - }) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(next_element) - - # No 0th dimension (i.e. scalar value) for one component. - sess.run( - iterator.initializer, - feed_dict={ - ph1: np.arange(7).astype(np.int32), - ph2: 7 - }) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(next_element) - - def testBatchAndDropRemainder(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size)) - .make_initializable_iterator()) - - next_element = iterator.get_next() - - with self.cached_session() as sess: - for test_batch_size in [1, 3, 7, 10]: - sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) - num_batches = 7 // test_batch_size - for i in range(num_batches): - result = sess.run(next_element) - for component, result_component in zip(components, result): - for j in range(test_batch_size): - self.assertAllEqual(component[(i * test_batch_size + j)], - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testBatchAndDropRemainderSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(12).map(_sparse).apply( - batching.batch_and_drop_remainder(5)).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - for i in range(2): - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], - values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], - dense_shape=[5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPaddedBatchAndDropRemainder(self): - els = [] - for length in [3, 6, 9, 4, 12, 10, 2]: - els.append((np.array(length), np.arange(length) + 1, - np.array(length * 2))) - - dataset = dataset_ops.Dataset.from_tensors(els[0]) - for el in els[1:]: - dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el)) - - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( - dataset.apply( - batching.padded_batch_and_drop_remainder( - batch_size, ([], [None], []))).make_initializable_iterator()) - - next_element = iterator.get_next() - - with self.cached_session() as sess: - for test_batch_size in [1, 3, 7, 10]: - sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) - num_batches = 7 // test_batch_size - for i in range(num_batches): - result = sess.run(next_element) - for component_idx, result_component in enumerate(result): - for j in range(test_batch_size): - data_idx = i * test_batch_size + j - comp = result_component[j] - unpadded = comp[comp > 0] - if np.isscalar(comp): - # The boolean mask indexing above adds a dim back. Rm it. - unpadded = unpadded[0] - self.assertAllEqual(els[data_idx][component_idx], unpadded) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPaddedBatchAndDropRemainderSparseError(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i - - with self.assertRaises(TypeError): - _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( - batching.padded_batch_and_drop_remainder(5)) - - def testBatchAndDropRemainderShapeInference(self): - components = (array_ops.placeholder(dtypes.int32), - (array_ops.placeholder(dtypes.int32, shape=[None]), - array_ops.placeholder(dtypes.int32, shape=[20, 30]))) - - # Test with a statically known batch size. - dataset = ( - dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(128))) - - self.assertIs(None, dataset.output_shapes[0].ndims) - self.assertEqual([128], dataset.output_shapes[1][0].as_list()) - self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list()) - - # Test with a dynamic batch size: the static shape will be unknown, because - # `batch_size` is a placeholder. - batch_size = array_ops.placeholder(dtypes.int64) - dataset = ( - dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size))) - - self.assertIs(None, dataset.output_shapes[0].ndims) - self.assertEqual([None], dataset.output_shapes[1][0].as_list()) - self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - - @parameterized.named_parameters( - ("Default", None, None), - ("SequentialCalls", 1, None), - ("ParallelCalls", 2, None), - ("ParallelBatches", None, 10), - ) - def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): - """Test a dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> - # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - count = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_calls=num_parallel_calls, - num_parallel_batches=num_parallel_batches)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([[None] + list(c.shape[1:]) for c in components], - [t.shape.as_list() for t in get_next]) - - with self.cached_session() as sess: - # Batch of a finite input, where the batch_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 28, batch_size: 14}) - num_batches = (28 * 7) // 14 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i * 14 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of a finite input, where the batch_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 14, batch_size: 8}) - - # We expect (num_batches - 1) full-sized batches. - num_batches = int(math.ceil((14 * 7) / 8)) - for i in range(num_batches - 1): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(8): - self.assertAllEqual(component[(i * 8 + j) % 7]**2, - result_component[j]) - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, batch_size: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Empty batch should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - - @parameterized.named_parameters( - ("Even", False), - ("Uneven", True), - ) - def testMapAndBatchPartialBatch(self, drop_remainder): - iterator = ( - dataset_ops.Dataset.range(10).apply( - batching.map_and_batch( - lambda x: array_ops.reshape(x * x, [1]), - batch_size=4, - drop_remainder=drop_remainder)).make_one_shot_iterator()) - if drop_remainder: - self.assertEqual([4, 1], iterator.output_shapes.as_list()) - else: - self.assertEqual([None, 1], iterator.output_shapes.as_list()) - next_element = iterator.get_next() - with self.cached_session() as sess: - self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) - self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) - if not drop_remainder: - self.assertAllEqual([[64], [81]], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testMapAndBatchYieldsPartialBatch(self): - iterator = (dataset_ops.Dataset.range(10) - .apply(batching.map_and_batch( - lambda x: array_ops.reshape(x * x, [1]), 4)) - .make_one_shot_iterator()) - self.assertEqual([None, 1], iterator.output_shapes.as_list()) - next_element = iterator.get_next() - with self.cached_session() as sess: - self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) - self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) - self.assertAllEqual([[64], [81]], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testMapAndBatchParallelGetNext(self): - iterator = (dataset_ops.Dataset.range(50000) - .apply(batching.map_and_batch(lambda x: x, batch_size=100)) - .make_one_shot_iterator()) - elements = [] - for _ in range(100): - elements.append(iterator.get_next()) - with self.cached_session() as sess: - for i in range(5): - got = sess.run(elements) - got.sort(key=lambda x: x[0]) - expected = [] - for j in range(100): - expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) - self.assertAllEqual(got, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elements) - - def testMapAndBatchParallelGetNextDropRemainder(self): - iterator = ( - dataset_ops.Dataset.range(49999).apply( - batching.map_and_batch( - lambda x: x, batch_size=100, drop_remainder=True)) - .make_one_shot_iterator()) - elements = [] - for _ in range(100): - elements.append(iterator.get_next()) - with self.cached_session() as sess: - for i in range(4): - got = sess.run(elements) - got.sort(key=lambda x: x[0]) - expected = [] - for j in range(100): - expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) - self.assertAllEqual(got, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elements) - - def testMapAndBatchSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(10).apply( - batching.map_and_batch(_sparse, 5)).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - for i in range(2): - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], - values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], - dense_shape=[5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMapAndBatchFails(self): - """Test a dataset that maps a TF function across its input elements.""" - dataset = dataset_ops.Dataset.from_tensors( - array_ops.check_numerics( - constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( - dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) - init_op = iterator.initializer - with self.cached_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(init_op, feed_dict={batch_size: 14}) - - def testMapAndBatchShapeMismatch(self): - """Test a dataset that maps a TF function across its input elements.""" - - def generator(): - yield [1] - yield [2] - yield [3] - yield [[4, 5, 6]] - - dataset = dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int32) - batch_size = 4 - iterator = ( - dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "number of elements does not match"): - sess.run(get_next) - - def testMapAndBatchImplicitDispose(self): - # Tests whether a map and batch dataset will be cleaned up correctly when - # the pipeline does not run it until exhaustion. - # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> - # MapAndBatchDataset(f=square_3, batch_size=100). - components = (np.arange(1000), - np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], - np.array(37.0) * np.arange(1000)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( - 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) - dataset = dataset.prefetch(5) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.cached_session() as sess: - for _ in range(3): - sess.run(get_next) - - @parameterized.named_parameters( - ("1", 0), - ("2", 5), - ("3", 10), - ("4", 90), - ("5", 95), - ("6", 99), - ) - def testMapAndBatchOutOfRangeError(self, threshold): - - def raising_py_fn(i): - if i >= threshold: - raise StopIteration() - else: - return i - - iterator = ( - dataset_ops.Dataset.range(100).apply( - batching.map_and_batch( - lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), - batch_size=10)).make_one_shot_iterator()) - get_next = iterator.get_next() - - with self.cached_session() as sess: - for i in range(threshold // 10): - self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) - if threshold % 10 != 0: - self.assertAllEqual( - [threshold // 10 * 10 + j for j in range(threshold % 10)], - sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - @parameterized.named_parameters( - ("1", False, dtypes.bool), - ("2", -42, dtypes.int8), - ("3", -42, dtypes.int16), - ("4", -42, dtypes.int32), - ("5", -42, dtypes.int64), - ("6", 42, dtypes.uint8), - ("7", 42, dtypes.uint16), - ("8", 42.0, dtypes.float16), - ("9", 42.0, dtypes.float32), - ("10", 42.0, dtypes.float64), - ("11", b"hello", dtypes.string), - ) - def testMapAndBatchTypes(self, element, dtype): - def gen(): - yield element - - dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( - batching.map_and_batch(lambda x: x, batch_size=10)) - - get_next = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - for _ in range(10): - self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) - - -class RestructuredDatasetTest(test.TestCase): - - def test_assert_element_shape(self): - - def create_dataset(_): - return (array_ops.ones(2, dtype=dtypes.float32), - array_ops.zeros((3, 4), dtype=dtypes.int32)) - - dataset = dataset_ops.Dataset.range(5).map(create_dataset) - expected_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 4))) - self.assertEqual(expected_shapes, dataset.output_shapes) - - result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) - - iterator = result.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.test_session() as sess: - sess.run(init_op) - for _ in range(5): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_assert_wrong_element_shape(self): - - def create_dataset(_): - return (array_ops.ones(2, dtype=dtypes.float32), - array_ops.zeros((3, 4), dtype=dtypes.int32)) - - dataset = dataset_ops.Dataset.range(3).map(create_dataset) - wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 10))) - with self.assertRaises(ValueError): - dataset.apply(batching.assert_element_shape(wrong_shapes)) - - def test_assert_element_shape_on_unknown_shape_dataset(self): - - def create_unknown_shape_dataset(x): - return script_ops.py_func( - lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) - - dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) - unknown_shapes = (tensor_shape.TensorShape(None), - tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) - - expected_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 4))) - result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) - - iterator = result.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op) - for _ in range(5): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): - - def create_unknown_shape_dataset(x): - return script_ops.py_func( - lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) - - dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) - unknown_shapes = (tensor_shape.TensorShape(None), - tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) - - wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - def test_assert_partial_element_shape(self): - - def create_dataset(_): - return (array_ops.ones(2, dtype=dtypes.float32), - array_ops.zeros((3, 4), dtype=dtypes.int32)) - - dataset = dataset_ops.Dataset.range(5).map(create_dataset) - partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape - tensor_shape.TensorShape((None, 4))) # Partial shape - result = dataset.apply( - batching.assert_element_shape(partial_expected_shape)) - # Partial shapes are merged with actual shapes: - actual_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 4))) - self.assertEqual(actual_shapes, result.output_shapes) - - iterator = result.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.test_session() as sess: - sess.run(init_op) - for _ in range(5): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_assert_wrong_partial_element_shape(self): - - def create_dataset(_): - return (array_ops.ones(2, dtype=dtypes.float32), - array_ops.zeros((3, 4), dtype=dtypes.int32)) - - dataset = dataset_ops.Dataset.range(3).map(create_dataset) - wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((None, 10))) - with self.assertRaises(ValueError): - dataset.apply(batching.assert_element_shape(wrong_shapes)) - - def test_assert_partial_element_shape_on_unknown_shape_dataset(self): - - def create_unknown_shape_dataset(x): - return script_ops.py_func( - lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) - - dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) - unknown_shapes = (tensor_shape.TensorShape(None), - tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) - - expected_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((None, 4))) - result = dataset.apply(batching.assert_element_shape(expected_shapes)) - self.assertEqual(expected_shapes, result.output_shapes) - - iterator = result.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.test_session() as sess: - sess.run(init_op) - for _ in range(5): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self): - - def create_unknown_shape_dataset(x): - return script_ops.py_func( - lambda _: ( # pylint: disable=g-long-lambda - np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) - - dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) - unknown_shapes = (tensor_shape.TensorShape(None), - tensor_shape.TensorShape(None)) - self.assertEqual(unknown_shapes, dataset.output_shapes) - - wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - -class UnbatchDatasetBenchmark(test.Benchmark): - - def benchmarkNativeUnbatch(self): - batch_sizes = [1, 2, 5, 10, 20, 50] - elems_per_trial = 10000 - with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) - batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - dataset = dataset.batch(batch_size_placeholder) - dataset = dataset.apply(batching.unbatch()) - dataset = dataset.skip(elems_per_trial) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for batch_size in batch_sizes: - deltas = [] - for _ in range(5): - sess.run( - iterator.initializer, - feed_dict={batch_size_placeholder: batch_size}) - start = time.time() - sess.run(next_element.op) - end = time.time() - deltas.append((end - start) / elems_per_trial) - - median_wall_time = np.median(deltas) - print("Unbatch (native) batch size: %d Median wall time per element:" - " %f microseconds" % (batch_size, median_wall_time * 1e6)) - self.report_benchmark( - iters=10000, - wall_time=median_wall_time, - name="benchmark_unbatch_dataset_native_batch_size_%d" % - batch_size) - - # Include a benchmark of the previous `unbatch()` implementation that uses - # a composition of more primitive ops. Eventually we'd hope to generate code - # that is as good in both cases. - def benchmarkOldUnbatchImplementation(self): - batch_sizes = [1, 2, 5, 10, 20, 50] - elems_per_trial = 10000 - with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) - batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - dataset = dataset.batch(batch_size_placeholder) - dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices) - dataset = dataset.skip(elems_per_trial) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for batch_size in batch_sizes: - deltas = [] - for _ in range(5): - sess.run( - iterator.initializer, - feed_dict={batch_size_placeholder: batch_size}) - start = time.time() - sess.run(next_element.op) - end = time.time() - deltas.append((end - start) / elems_per_trial) - - median_wall_time = np.median(deltas) - print("Unbatch (unfused) batch size: %d Median wall time per element:" - " %f microseconds" % (batch_size, median_wall_time * 1e6)) - self.report_benchmark( - iters=10000, - wall_time=median_wall_time, - name="benchmark_unbatch_dataset_unfused_batch_size_%d" % - batch_size) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py deleted file mode 100644 index 94718bb477d411259f96f74bca27613575df5591..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ /dev/null @@ -1,783 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import random - -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class GroupByReducerTest(test.TestCase): - - def checkResults(self, dataset, shapes, values): - self.assertEqual(shapes, dataset.output_shapes) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - for expected in values: - got = sess.run(get_next) - self.assertEqual(got, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSum(self): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).apply( - grouping.group_by_reducer(lambda x: x % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) - - def testAverage(self): - - def reduce_fn(x, y): - return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( - x[1] + 1), x[1] + 1 - - reducer = grouping.Reducer( - init_func=lambda _: (0.0, 0.0), - reduce_func=reduce_fn, - finalize_func=lambda x, _: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).apply( - grouping.group_by_reducer( - lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) - - def testConcat(self): - components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) - reducer = grouping.Reducer( - init_func=lambda x: "", - reduce_func=lambda x, y: x + y[0], - finalize_func=lambda x: x) - for i in range(1, 11): - dataset = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensor_slices(components), - dataset_ops.Dataset.range(2 * i))).apply( - grouping.group_by_reducer(lambda x, y: y % 2, reducer)) - self.checkResults( - dataset, - shapes=tensor_shape.scalar(), - values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) - - def testSparseSum(self): - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1], dtype=np.int64)), - dense_shape=np.array([1, 1])) - - reducer = grouping.Reducer( - init_func=lambda _: _sparse(np.int64(0)), - reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), - finalize_func=lambda x: x.values[0]) - for i in range(1, 11): - dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( - grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) - self.checkResults( - dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) - - def testChangingStateShape(self): - - def reduce_fn(x, _): - # Statically known rank, but dynamic length. - larger_dim = array_ops.concat([x[0], x[0]], 0) - # Statically unknown rank. - larger_rank = array_ops.expand_dims(x[1], 0) - return larger_dim, larger_rank - - reducer = grouping.Reducer( - init_func=lambda x: ([0], 1), - reduce_func=reduce_fn, - finalize_func=lambda x, y: (x, y)) - - for i in range(1, 11): - dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( - grouping.group_by_reducer(lambda x: x, reducer)) - self.assertEqual([None], dataset.output_shapes[0].as_list()) - self.assertIs(None, dataset.output_shapes[1].ndims) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - with self.cached_session() as sess: - x, y = sess.run(get_next) - self.assertAllEqual([0] * (2**i), x) - self.assertAllEqual(np.array(1, ndmin=i), y) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testTypeMismatch(self): - reducer = grouping.Reducer( - init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), - reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - TypeError, - "The element types for the new state must match the initial state."): - dataset.apply( - grouping.group_by_reducer(lambda _: np.int64(0), reducer)) - - # TODO(b/78665031): Remove once non-scalar keys are supported. - def testInvalidKeyShape(self): - reducer = grouping.Reducer( - init_func=lambda x: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - ValueError, "`key_func` must return a single tf.int64 tensor."): - dataset.apply( - grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) - - # TODO(b/78665031): Remove once non-int64 keys are supported. - def testInvalidKeyType(self): - reducer = grouping.Reducer( - init_func=lambda x: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - - dataset = dataset_ops.Dataset.range(10) - with self.assertRaisesRegexp( - ValueError, "`key_func` must return a single tf.int64 tensor."): - dataset.apply( - grouping.group_by_reducer(lambda _: "wrong", reducer)) - - def testTuple(self): - def init_fn(_): - return np.array([], dtype=np.int64), np.int64(0) - - def reduce_fn(state, value): - s1, s2 = state - v1, v2 = value - return array_ops.concat([s1, [v1]], 0), s2 + v2 - - def finalize_fn(s1, s2): - return s1, s2 - - reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) - dataset = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( - grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - x, y = sess.run(get_next) - self.assertAllEqual(x, np.asarray([x for x in range(10)])) - self.assertEqual(y, 45) - - -class GroupByWindowTest(test.TestCase): - - def testSimple(self): - components = np.random.randint(100, size=(200,)).astype(np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) - .apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - counts = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - result = sess.run(get_next) - self.assertTrue( - all(x % 2 == 0 - for x in result) or all(x % 2 == 1) - for x in result) - counts.append(result.shape[0]) - - self.assertEqual(len(components), sum(counts)) - num_full_batches = len([c for c in counts if c == 4]) - self.assertGreaterEqual(num_full_batches, 24) - self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) - - def testImmediateOutput(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - # The input is infinite, so this test demonstrates that: - # 1. We produce output without having to consume the entire input, - # 2. Different buckets can produce output at different rates, and - # 3. For deterministic input, the output is deterministic. - for _ in range(3): - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) - self.assertAllEqual([2, 2, 2, 2], sess.run(get_next)) - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - - def testSmallGroups(self): - components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) - # The small outputs at the end are deterministically produced in key - # order. - self.assertAllEqual([0, 0, 0], sess.run(get_next)) - self.assertAllEqual([1], sess.run(get_next)) - - def testEmpty(self): - iterator = ( - dataset_ops.Dataset.range(4).apply( - grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Window size must be greater than zero, but got 0."): - print(sess.run(get_next)) - - def testReduceFuncError(self): - components = np.random.randint(100, size=(200,)).astype(np.int64) - - def reduce_func(_, xs): - # Introduce an incorrect padded shape that cannot (currently) be - # detected at graph construction time. - return xs.padded_batch( - 4, - padded_shapes=(tensor_shape.TensorShape([]), - constant_op.constant([5], dtype=dtypes.int64) * -1)) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( - grouping.group_by_window(lambda x, _: x % 2, reduce_func, - 32)).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - def testConsumeWindowDatasetMoreThanOnce(self): - components = np.random.randint(50, size=(200,)).astype(np.int64) - - def reduce_func(key, window): - # Apply two different kinds of padding to the input: tight - # padding, and quantized (to a multiple of 10) padding. - return dataset_ops.Dataset.zip(( - window.padded_batch( - 4, padded_shapes=tensor_shape.TensorShape([None])), - window.padded_batch( - 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])), - )) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x)) - .apply(grouping.group_by_window( - lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), - reduce_func, 4)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - counts = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - tight_result, multiple_of_10_result = sess.run(get_next) - self.assertEqual(0, multiple_of_10_result.shape[1] % 10) - self.assertAllEqual(tight_result, - multiple_of_10_result[:, :tight_result.shape[1]]) - counts.append(tight_result.shape[0]) - self.assertEqual(len(components), sum(counts)) - - -# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. -# Currently, they use a constant batch size, though should be made to use a -# different batch size per key. -class BucketTest(test.TestCase): - - def _dynamicPad(self, bucket, window, window_size): - # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a - # generic form of padded_batch that pads every component - # dynamically and does not rely on static shape information about - # the arguments. - return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), - window.padded_batch( - 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( - [None]), tensor_shape.TensorShape([3]))))) - - def testSingleBucket(self): - - def _map_fn(v): - return (v, array_ops.fill([v], v), - array_ops.fill([3], string_ops.as_string(v))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn)) - - bucketed_dataset = input_dataset.apply( - grouping.group_by_window( - lambda x, y, z: 0, - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - which_bucket, bucketed_values = sess.run(get_next) - - self.assertEqual(0, which_bucket) - - expected_scalar_int = np.arange(32, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31)).astype(np.int64) - for i in range(32): - expected_unk_int64[i, :i] = i - expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values[2]) - - def testEvenOddBuckets(self): - - def _map_fn(v): - return (v, array_ops.fill([v], v), - array_ops.fill([3], string_ops.as_string(v))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn)) - - bucketed_dataset = input_dataset.apply( - grouping.group_by_window( - lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - # Get two minibatches (one containing even values, one containing odds) - which_bucket_even, bucketed_values_even = sess.run(get_next) - which_bucket_odd, bucketed_values_odd = sess.run(get_next) - - # Count number of bucket_tensors. - self.assertEqual(3, len(bucketed_values_even)) - self.assertEqual(3, len(bucketed_values_odd)) - - # Ensure bucket 0 was used for all minibatch entries. - self.assertAllEqual(0, which_bucket_even) - self.assertAllEqual(1, which_bucket_odd) - - # Test the first bucket outputted, the events starting at 0 - expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64) - for i in range(0, 32): - expected_unk_int64[i, :2 * i] = 2 * i - expected_vec3_str = np.vstack( - 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values_even[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values_even[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values_even[2]) - - # Test the second bucket outputted, the odds starting at 1 - expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64) - expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64) - for i in range(0, 32): - expected_unk_int64[i, :2 * i + 1] = 2 * i + 1 - expected_vec3_str = np.vstack( - 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T - - self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0]) - self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1]) - self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) - - def testEvenOddBucketsFilterOutAllOdd(self): - - def _map_fn(v): - return { - "x": v, - "y": array_ops.fill([v], v), - "z": array_ops.fill([3], string_ops.as_string(v)) - } - - def _dynamic_pad_fn(bucket, window, _): - return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), - window.padded_batch( - 32, { - "x": tensor_shape.TensorShape([]), - "y": tensor_shape.TensorShape([None]), - "z": tensor_shape.TensorShape([3]) - }))) - - input_dataset = ( - dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) - .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) - - bucketed_dataset = input_dataset.apply( - grouping.group_by_window( - lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), - lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)) - - iterator = bucketed_dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - - # Get two minibatches ([0, 2, ...] and [64, 66, ...]) - which_bucket0, bucketed_values_even0 = sess.run(get_next) - which_bucket1, bucketed_values_even1 = sess.run(get_next) - - # Ensure that bucket 1 was completely filtered out - self.assertAllEqual(0, which_bucket0) - self.assertAllEqual(0, which_bucket1) - self.assertAllEqual( - np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"]) - self.assertAllEqual( - np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) - - def testDynamicWindowSize(self): - components = np.arange(100).astype(np.int64) - - # Key fn: even/odd - # Reduce fn: batches of 5 - # Window size fn: even=5, odd=10 - - def window_size_func(key): - window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64) - return window_sizes[key] - - dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20), - None, window_size_func)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.cached_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.OutOfRangeError): - batches = 0 - while True: - result = sess.run(get_next) - is_even = all(x % 2 == 0 for x in result) - is_odd = all(x % 2 == 1 for x in result) - self.assertTrue(is_even or is_odd) - expected_batch_size = 5 if is_even else 10 - self.assertEqual(expected_batch_size, result.shape[0]) - batches += 1 - - self.assertEqual(batches, 15) - - -def _element_length_fn(x, y=None): - del y - return array_ops.shape(x)[0] - - -class BucketBySequenceLength(test.TestCase): - - def testBucket(self): - - boundaries = [10, 20, 30] - batch_sizes = [10, 8, 4, 2] - lengths = [8, 13, 25, 35] - - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes, lengths): - record_len = length - 1 - for _ in range(batch_size): - elements.append([1] * record_len) - record_len = length - random.shuffle(elements) - for el in elements: - yield (el,) - - def _test_bucket_by_padding(no_padding): - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)) - if no_padding: - dataset = dataset.map(lambda x: (layers.dense_to_sparse(x),)) - dataset = dataset.apply( - grouping.bucket_by_sequence_length( - _element_length_fn, - boundaries, - batch_sizes, - no_padding=no_padding)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(4): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - shape = batch.dense_shape if no_padding else batch.shape - batch_size = shape[0] - length = shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - sum_check = batch.values.sum() if no_padding else batch.sum() - self.assertEqual(sum_check, batch_size * length - 1) - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(lengths), sorted(lengths_val)) - - for no_padding in (True, False): - _test_bucket_by_padding(no_padding) - - def testPadToBoundary(self): - - boundaries = [10, 20, 30] - batch_sizes = [10, 8, 4, 2] - lengths = [8, 13, 25] - - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes[:-1], lengths): - for _ in range(batch_size): - elements.append([1] * length) - random.shuffle(elements) - for el in elements: - yield (el,) - for _ in range(batch_sizes[-1]): - el = [1] * (boundaries[-1] + 5) - yield (el,) - - element_len = lambda el: array_ops.shape(el)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes, - pad_to_bucket_boundary=True)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(3): - batches.append(sess.run(batch)) - with self.assertRaisesOpError("bucket_boundaries"): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - batch_size = batch.shape[0] - length = batch.shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - batch_sizes = batch_sizes[:-1] - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], - sorted(lengths_val)) - - def testPadToBoundaryNoExtraneousPadding(self): - - boundaries = [3, 7, 11] - batch_sizes = [2, 2, 2, 2] - lengths = range(1, 11) - - def element_gen(): - for length in lengths: - yield ([1] * length,) - - element_len = lambda element: array_ops.shape(element)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes, - pad_to_bucket_boundary=True)) - batch, = dataset.make_one_shot_iterator().get_next() - - with self.cached_session() as sess: - batches = [] - for _ in range(5): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - - self.assertAllEqual(batches[0], [[1, 0], - [1, 1]]) - self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0]]) - self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1]]) - self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) - self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - - def testTupleElements(self): - - def elements_gen(): - text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] - label = [1, 2, 1, 2] - for x, y in zip(text, label): - yield (x, y) - - def _test_tuple_elements_by_padding(no_padding): - dataset = dataset_ops.Dataset.from_generator( - generator=elements_gen, - output_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([])), - output_types=(dtypes.int32, dtypes.int32)) - if no_padding: - dataset = dataset.map(lambda x, y: (layers.dense_to_sparse(x), y)) - dataset = dataset.apply(grouping.bucket_by_sequence_length( - element_length_func=_element_length_fn, - bucket_batch_sizes=[2, 2, 2], - bucket_boundaries=[0, 8], - no_padding=no_padding)) - shapes = dataset.output_shapes - self.assertEqual([None, None], shapes[0].as_list()) - self.assertEqual([None], shapes[1].as_list()) - - for no_padding in (True, False): - _test_tuple_elements_by_padding(no_padding) - - def testBucketSparse(self): - """Tests bucketing of sparse tensors (case where `no_padding` == True). - - Test runs on following dataset: - [ - [0], - [0, 1], - [0, 1, 2] - ... - [0, ..., max_len - 1] - ] - Sequences are bucketed by length and batched with - `batch_size` < `bucket_size`. - """ - - min_len = 0 - max_len = 100 - batch_size = 7 - bucket_size = 10 - - def _build_dataset(): - input_data = [range(i+1) for i in range(min_len, max_len)] - def generator_fn(): - for record in input_data: - yield record - dataset = dataset_ops.Dataset.from_generator( - generator=generator_fn, - output_shapes=(tensor_shape.TensorShape([None])), - output_types=(dtypes.int64)) - dataset = dataset.map(lambda x: layers.dense_to_sparse(x, eos_token=-1)) - return dataset - - def _compute_expected_batches(): - """Computes expected batch outputs and stores in a set.""" - all_expected_sparse_tensors = set() - for bucket_start_len in range(min_len, max_len, bucket_size): - for batch_offset in range(0, bucket_size, batch_size): - batch_start_len = bucket_start_len + batch_offset - batch_end_len = min(batch_start_len + batch_size, - bucket_start_len + bucket_size) - expected_indices = [] - expected_values = [] - for length in range(batch_start_len, batch_end_len): - for val in range(length + 1): - expected_indices.append((length - batch_start_len, val)) - expected_values.append(val) - expected_sprs_tensor = (tuple(expected_indices), - tuple(expected_values)) - all_expected_sparse_tensors.add(expected_sprs_tensor) - return all_expected_sparse_tensors - - def _compute_batches(dataset): - """Computes actual batch outputs of dataset and stores in a set.""" - batch = dataset.make_one_shot_iterator().get_next() - all_sparse_tensors = set() - with self.cached_session() as sess: - with self.assertRaises(errors.OutOfRangeError): - while True: - output = sess.run(batch) - sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), - tuple(output.values)) - all_sparse_tensors.add(sprs_tensor) - return all_sparse_tensors - - dataset = _build_dataset() - boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) - dataset = dataset.apply(grouping.bucket_by_sequence_length( - _element_length_fn, - boundaries, - [batch_size] * (len(boundaries) + 1), - no_padding=True)) - batches = _compute_batches(dataset) - expected_batches = _compute_expected_batches() - self.assertEqual(batches, expected_batches) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 1cc5ddc9a2e1eff4473c19bc397d656e7e0b90ed..d2a72272db159755ac2d741bcdbce9ec646d928e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -22,6 +22,7 @@ import os import shutil from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,7 @@ from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" -class LMDBDatasetTest(test.TestCase): +class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): super(LMDBDatasetTest, self).setUp() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD deleted file mode 100644 index 7e9ea68047a076d368cf98960f4754b29abb074e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ /dev/null @@ -1,107 +0,0 @@ -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - -py_test( - name = "assert_next_dataset_op_test", - size = "medium", - srcs = ["assert_next_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_test( - name = "latency_all_edges_test", - size = "small", - srcs = ["latency_all_edges_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -py_test( - name = "map_vectorization_test", - size = "small", - srcs = ["map_vectorization_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/kernel_tests:test_utils", - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "map_and_filter_fusion_test", - size = "medium", - srcs = ["map_and_filter_fusion_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - "@absl_py//absl/testing:parameterized", - ], -) - -py_test( - name = "model_dataset_op_test", - size = "medium", - srcs = ["model_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "optonly", - ], - deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "optimize_dataset_op_test", - size = "small", - srcs = ["optimize_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py deleted file mode 100644 index 0166ba0d44ef473ac54ee4f67078c1a51fddacf3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ /dev/null @@ -1,1098 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for prefetching_ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import threading - -from tensorflow.contrib.data.python.ops import prefetching_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.compat import compat -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.platform import test - - -class PrefetchingKernelsOpsTest(test.TestCase): - - def setUp(self): - self._event = threading.Event() - - def _create_ds_and_iterator(self, device0, initializable=False): - - def gen(): - for i in range(1, 10): - yield [float(i)] - if i == 6: - self._event.set() - - with ops.device(device0): - ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) - if initializable: - ds_iterator = ds.make_initializable_iterator() - else: - ds_iterator = ds.make_one_shot_iterator() - return (ds, ds_iterator) - - def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): - ds_iterator_handle = ds_iterator.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, ds.output_types, ds.output_shapes) - return remote_iterator.get_next() - - target = constant_op.constant(device0) - with ops.device(device1): - buffer_resource_handle = prefetching_ops.function_buffering_resource( - f=_remote_fn, - output_types=[dtypes.float32], - target_device=target, - string_arg=ds_iterator_handle, - buffer_size=3, - shared_name=buffer_name) - - with ops.device(device1): - prefetch_op = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=buffer_resource_handle, - output_types=[dtypes.float32]) - reset_op = prefetching_ops.function_buffering_resource_reset( - function_buffer_resource=buffer_resource_handle) - destroy_op = resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True) - - return (prefetch_op, reset_op, destroy_op) - - def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - - ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) - prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, - device0, device1) - - with self.test_session(config=worker_config) as sess: - elem = sess.run(prefetch_op) - self.assertEqual(elem, [1.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [2.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [3.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [4.0]) - self._event.wait() - elem = sess.run(prefetch_op) - self.assertEqual(elem, [5.0]) - sess.run(destroy_op) - - def testSameDeviceCPU(self): - self._prefetch_fn_helper_one_shot("same_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:0") - - def testDifferentDeviceCPU(self): - self._prefetch_fn_helper_one_shot("diff_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:1") - - def testDifferentDeviceCPUGPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - self._prefetch_fn_helper_one_shot("cpu_gpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/gpu:0") - - def testReinitialization(self): - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - - device0 = "/job:localhost/replica:0/task:0/cpu:0" - device1 = "/job:localhost/replica:0/task:0/cpu:1" - ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) - prefetch_op, reset_op, destroy_op = self._create_ops( - ds, ds_iterator, "reinit", device0, device1) - - with self.test_session(config=worker_config) as sess: - sess.run(ds_iterator.initializer) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [1.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [2.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [3.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [4.0]) - self._event.wait() - elem = sess.run(prefetch_op) - self.assertEqual(elem, [5.0]) - # Lets reset the function buffering resource and reinitialize the - # iterator. Should be able to go through this again. - self._event.clear() - sess.run(reset_op) - sess.run(ds_iterator.initializer) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [1.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [2.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [3.0]) - elem = sess.run(prefetch_op) - self.assertEqual(elem, [4.0]) - self._event.wait() - elem = sess.run(prefetch_op) - self.assertEqual(elem, [5.0]) - sess.run(destroy_op) - - def testReinitializationOutOfRange(self): - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - - device0 = "/job:localhost/replica:0/task:0/cpu:0" - device1 = "/job:localhost/replica:0/task:0/cpu:1" - ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) - prefetch_op, reset_op, destroy_op = self._create_ops( - ds, ds_iterator, "reinit", device0, device1) - - with self.test_session(config=worker_config) as sess: - sess.run(ds_iterator.initializer) - for i in range(1, 10): - elem = sess.run(prefetch_op) - self.assertEqual(elem, [float(i)]) - # Try fetching after its over twice to test out end of sequence. - with self.assertRaises(errors.OutOfRangeError): - sess.run(prefetch_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(prefetch_op) - - # Now reset everything and try it out again. - self._event.clear() - sess.run(reset_op) - sess.run(ds_iterator.initializer) - for i in range(1, 10): - elem = sess.run(prefetch_op) - self.assertEqual(elem, [float(i)]) - # Try fetching after its over twice to test out end of sequence. - with self.assertRaises(errors.OutOfRangeError): - sess.run(prefetch_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(prefetch_op) - - sess.run(destroy_op) - - def testStringsGPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - device0 = "/job:localhost/replica:0/task:0/cpu:0" - device1 = "/job:localhost/replica:0/task:0/gpu:0" - - ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) - ds_iterator = ds.make_one_shot_iterator() - ds_iterator_handle = ds_iterator.string_handle() - - @function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, ds.output_types, ds.output_shapes) - return remote_iterator.get_next() - - target = constant_op.constant(device0) - with ops.device(device1): - buffer_resource_handle = prefetching_ops.function_buffering_resource( - f=_remote_fn, - output_types=[dtypes.string], - target_device=target, - string_arg=ds_iterator_handle, - buffer_size=3, - shared_name="strings") - - with ops.device(device1): - prefetch_op = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=buffer_resource_handle, - output_types=[dtypes.string]) - destroy_op = resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True) - - with self.cached_session() as sess: - self.assertEqual([b"a"], sess.run(prefetch_op)) - self.assertEqual([b"b"], sess.run(prefetch_op)) - self.assertEqual([b"c"], sess.run(prefetch_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(prefetch_op) - - sess.run(destroy_op) - - -class PrefetchToDeviceTest(test.TestCase): - - def testPrefetchToDevice(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/cpu:1")) - - # NOTE(mrry): This device block creates the "host" dataset and iterator on - # /cpu:0, and ensures that the prefetching is across devices. In typical use - # this would not be necessary, because the GPU device would not support any - # of the dataset-related ops. - with ops.device("/cpu:0"): - iterator = device_dataset.make_one_shot_iterator() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - next_element = iterator.get_next() - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToSameDevice(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device( - "/job:localhost/replica:0/task:0/device:CPU:0")) - - # NOTE(mrry): This device block creates the "host" dataset and iterator on - # /cpu:0, and ensures that the prefetching is across devices. In typical use - # this would not be necessary, because the GPU device would not support any - # of the dataset-related ops. - with ops.device("/cpu:0"): - iterator = device_dataset.make_one_shot_iterator() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - next_element = iterator.get_next() - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchDictToDevice(self): - host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/cpu:1")) - - # NOTE(mrry): This device block creates the "host" dataset and iterator on - # /cpu:0, and ensures that the prefetching is across devices. In typical use - # this would not be necessary, because the GPU device would not support any - # of the dataset-related ops. - with ops.device("/cpu:0"): - iterator = device_dataset.make_one_shot_iterator() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - next_element = iterator.get_next() - self.assertEqual(dtypes.int64, next_element["a"].dtype) - self.assertEqual([], next_element["a"].shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual({"a": i}, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchSparseTensorsToDevice(self): - def make_tensor(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2]) - host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) - - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/cpu:1")) - - # NOTE(mrry): This device block creates the "host" dataset and iterator on - # /cpu:0, and ensures that the prefetching is across devices. In typical use - # this would not be necessary, because the GPU device would not support any - # of the dataset-related ops. - with ops.device("/cpu:0"): - iterator = device_dataset.make_one_shot_iterator() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - next_element = iterator.get_next() - self.assertEqual(dtypes.int64, next_element.dtype) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - actual = sess.run(next_element) - self.assertAllEqual([i], actual.values) - self.assertAllEqual([[0, 0]], actual.indices) - self.assertAllEqual([2, 2], actual.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToDeviceGpu(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToDeviceWithReInit(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/cpu:1")) - - # NOTE(mrry): This device block creates the "host" dataset and iterator on - # /cpu:0, and ensures that the prefetching is across devices. In typical use - # this would not be necessary, because the GPU device would not support any - # of the dataset-related ops. - with ops.device("/cpu:0"): - iterator = device_dataset.make_initializable_iterator() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - next_element = iterator.get_next() - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToDeviceGpuWithReInit(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - -class CopyToDeviceTest(test.TestCase): - - def testCopyToDevice(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceInt32(self): - host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int32, next_element.dtype) - self.assertEqual((4,), next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToSameDevice(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:0")) - - with ops.device("/cpu:0"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceWithPrefetch(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyDictToDevice(self): - host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element["a"].dtype) - self.assertEqual([], next_element["a"].shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual({"a": i}, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyDictToDeviceWithPrefetch(self): - host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element["a"].dtype) - self.assertEqual([], next_element["a"].shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - self.assertEqual({"a": i}, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopySparseTensorsToDevice(self): - - def make_tensor(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) - - host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) - - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - actual = sess.run(next_element) - self.assertAllEqual([i], actual.values) - self.assertAllEqual([[0, 0]], actual.indices) - self.assertAllEqual([2, 2], actual.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopySparseTensorsToDeviceWithPrefetch(self): - - def make_tensor(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) - - host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) - - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - for i in range(10): - actual = sess.run(next_element) - self.assertAllEqual([i], actual.values) - self.assertAllEqual([[0, 0]], actual.indices) - self.assertAllEqual([2, 2], actual.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpu(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuWithPrefetch(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuInt32(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuInt32AndPrefetch(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuStrings(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuStringsAndPrefetch(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDevicePingPongCPUGPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with compat.forward_compatibility_horizon(2018, 8, 4): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0")) - back_to_cpu_dataset = device_dataset.apply( - prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0")) - - with ops.device("/cpu:0"): - iterator = back_to_cpu_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceWithReInit(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceWithReInitAndPrefetch(self): - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) - - with ops.device("/cpu:1"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - self.assertEqual(host_dataset.output_types, device_dataset.output_types) - self.assertEqual(host_dataset.output_types, iterator.output_types) - self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) - self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) - self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) - self.assertEqual(host_dataset.output_classes, iterator.output_classes) - - self.assertEqual(dtypes.int64, next_element.dtype) - self.assertEqual([], next_element.shape) - - worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=worker_config) as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuWithReInit(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testCopyToDeviceGpuWithReInitAndPrefetch(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) - - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - sess.run(iterator.initializer) - for i in range(5): - self.assertEqual(i, sess.run(next_element)) - sess.run(iterator.initializer) - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testIteratorGetNextAsOptionalOnGPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(3) - device_dataset = host_dataset.apply( - prefetching_ops.copy_to_device("/gpu:0")) - with ops.device("/gpu:0"): - iterator = device_dataset.make_initializable_iterator() - next_elem = iterator_ops.get_next_as_optional(iterator) - elem_has_value_t = next_elem.has_value() - elem_value_t = next_elem.get_value() - - with self.cached_session() as sess: - # Before initializing the iterator, evaluating the optional fails with - # a FailedPreconditionError. - with self.assertRaises(errors.FailedPreconditionError): - sess.run(elem_has_value_t) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(elem_value_t) - - # For each element of the dataset, assert that the optional evaluates to - # the expected value. - sess.run(iterator.initializer) - for i in range(3): - elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) - self.assertTrue(elem_has_value) - self.assertEqual(i, elem_value) - - # After exhausting the iterator, `next_elem.has_value()` will evaluate to - # false, and attempting to get the value will fail. - for _ in range(2): - self.assertFalse(sess.run(elem_has_value_t)) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(elem_value_t) - - -class MultiDeviceIteratorTest(test.TestCase): - - def testBasic(self): - dataset = dataset_ops.Dataset.range(10) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"]) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 3}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 10, 2): - self.assertEqual(i, sess.run(elem_on_1)) - self.assertEqual(i + 1, sess.run(elem_on_2)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - def testOneOnSameDevice(self): - with ops.device("/cpu:0"): - dataset = dataset_ops.Dataset.range(10) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:0", "/cpu:1"]) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 2}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 10, 2): - self.assertEqual(i, sess.run(elem_on_1)) - self.assertEqual(i + 1, sess.run(elem_on_2)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - def testRepeatDevices(self): - with ops.device("/cpu:0"): - dataset = dataset_ops.Dataset.range(20) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"]) - elements = multi_device_iterator.get_next() - elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements - - config = config_pb2.ConfigProto(device_count={"CPU": 3}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 20, 4): - self.assertEqual(i, sess.run(elem_on_1)) - self.assertEqual(i + 1, sess.run(elem_on_2)) - self.assertEqual(i + 2, sess.run(elem_on_3)) - self.assertEqual(i + 3, sess.run(elem_on_4)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - sess.run(elem_on_3) - sess.run(elem_on_4) - - def testNotFullyDivisible(self): - dataset = dataset_ops.Dataset.range(9) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"]) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 3}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 8, 2): - self.assertEqual(i, sess.run(elem_on_1)) - self.assertEqual(i + 1, sess.run(elem_on_2)) - self.assertEqual(8, sess.run(elem_on_1)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - def testUneven(self): - dataset = dataset_ops.Dataset.range(10) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 3}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 10, 2): - self.assertEqual(i, sess.run(elem_on_1)) - for i in range(0, 10, 2): - self.assertEqual(i + 1, sess.run(elem_on_2)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - def testMultipleInitializations(self): - with ops.device("/cpu:0"): - epoch = array_ops.placeholder(dtypes.int64, shape=[]) - dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000) - dataset2 = dataset_ops.Dataset.range(1000) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - init_op = multi_device_iterator.initializer - - config = config_pb2.ConfigProto(device_count={"CPU": 3}) - with self.test_session(config=config) as sess: - for i in range(1000): - sess.run(init_op, feed_dict={epoch: i}) - self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2])) - - def testBasicGpu(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with compat.forward_compatibility_horizon(2018, 8, 4): - dataset = dataset_ops.Dataset.range(10) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/gpu:0"]) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 10, 2): - self.assertEqual(i, sess.run(elem_on_1)) - self.assertEqual(i + 1, sess.run(elem_on_2)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - def testUnevenGpu(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with compat.forward_compatibility_horizon(2018, 8, 4): - dataset = dataset_ops.Dataset.range(10) - multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4) - elem_on_1, elem_on_2 = multi_device_iterator.get_next() - - config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) - with self.test_session(config=config) as sess: - sess.run(multi_device_iterator.initializer) - for i in range(0, 10, 2): - self.assertEqual(i, sess.run(elem_on_1)) - for i in range(0, 10, 2): - self.assertEqual(i + 1, sess.run(elem_on_2)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(elem_on_1) - sess.run(elem_on_2) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7281d531870c75c638b5c48fa3fc6dc606a3623 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + + @parameterized.named_parameters( + ("SumZero", 0), + ("SumOne", 1), + ("SumFive", 5), + ("SumTen", 10), + ) + def testReduceDataset(self, stop): + def init_fn(_): + return np.int64(0) + + def reduce_fn(state, value): + return state + value + + def finalize_fn(state): + return state + + sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + + stop_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset_ops.Dataset.range(stop_t) + element = get_single_element.reduce_dataset(dataset, sum_reducer) + + with self.cached_session() as sess: + value = sess.run(element, feed_dict={stop_t: stop}) + self.assertEqual(stop * (stop - 1) / 2, value) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 90d18dca2aa727ea51d636cb971f48b50bc0c663..c5a786232252432481566e3cde23e9310df172cc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class SlideDatasetTest(test.TestCase, parameterized.TestCase): +class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("1", 20, 14, 7, 1), @@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): sliding.sliding_window_batch( window_size=1, stride=1, window_shift=1, window_stride=1)) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSlideSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py deleted file mode 100644 index 4c3353fe4046d6b2bfabac580b46f88c8d7f2941..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test utilities for tf.data functionality.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re - -from tensorflow.python.data.util import nest -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DatasetTestBase(test.TestCase): - """Base class for dataset tests.""" - - def _assert_datasets_equal(self, dataset1, dataset2): - # TODO(rachelim): support sparse tensor outputs - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - while True: - try: - op1 = sess.run(next1) - except errors.OutOfRangeError: - with self.assertRaises(errors.OutOfRangeError): - sess.run(next2) - break - op2 = sess.run(next2) - - op1 = nest.flatten(op1) - op2 = nest.flatten(op2) - assert len(op1) == len(op2) - for i in range(len(op1)): - self.assertAllEqual(op1[i], op2[i]) - - def _assert_datasets_raise_same_error(self, - dataset1, - dataset2, - exception_class, - replacements=None): - # We are defining next1 and next2 in the same line so that we get identical - # file:line_number in the error messages - # pylint: disable=line-too-long - next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() - # pylint: enable=line-too-long - with self.cached_session() as sess: - try: - sess.run(next1) - raise ValueError( - "Expected dataset to raise an error of type %s, but it did not." % - repr(exception_class)) - except exception_class as e: - expected_message = e.message - for old, new, count in replacements: - expected_message = expected_message.replace(old, new, count) - # Check that the first segment of the error messages are the same. - with self.assertRaisesRegexp(exception_class, - re.escape(expected_message)): - sess.run(next2) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py deleted file mode 100644 index 6eaa0b195911acb057b30b8ca7408cdbfdce8352..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ /dev/null @@ -1,525 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.platform import test - - -class WindowDatasetTest(test.TestCase, parameterized.TestCase): - - def _structuredDataset(self, structure, shape, dtype): - if structure is None: - return dataset_ops.Dataset.from_tensors( - array_ops.zeros(shape, dtype=dtype)) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredDataset(substructure, shape, dtype) - for substructure in structure - ])) - - def _structuredElement(self, structure, shape, dtype): - if structure is None: - return array_ops.zeros(shape, dtype=dtype) - else: - return tuple([ - self._structuredElement(substructure, shape, dtype) - for substructure in structure - ]) - - def _assertEqual(self, xs, ys): - self.assertEqual(type(xs), type(ys)) - if isinstance(xs, tuple) and isinstance(ys, tuple): - self.assertEqual(len(xs), len(ys)) - for x, y in zip(xs, ys): - self._assertEqual(x, y) - elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray): - self.assertAllEqual(xs, ys) - else: - self.assertEqual(xs, ys) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), np.int32([]), dtypes.int32), - ) - def testWindowDatasetFlatMap(self, structure, shape, dtype): - """Tests windowing by chaining it with flat map. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return args[0] - return dataset_ops.Dataset.zip( - tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args])) - - dataset = self._structuredDataset(structure, shape, dtype).apply( - grouping.window_dataset(5)).flat_map(fn) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run(self._structuredElement(structure, shape, dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), np.int32([]), dtypes.int32), - ) - def testWindowDatasetBatchDense(self, structure, shape, dtype): - """Tests batching of dense tensor windows. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.batch_window(args[0]) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) - for arg in args - ]) - - dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( - grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredElement(structure, np.concatenate( - ([5], shape), axis=0), dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([])), - ("2", np.int32([1])), - ("3", np.int32([1, 2, 3])), - ) - def testWindowDatasetBatchDenseDynamicShape(self, shape): - """Tests batching of dynamically shaped dense tensor windows. - - Args: - shape: the input shape - """ - - shape_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensors( - array_ops.zeros(shape_t)).repeat(5).apply( - grouping.window_dataset(5)).apply( - grouping._map_x_dataset(batching.batch_window)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shape_t: shape}) - expected = sess.run( - self._structuredElement(None, np.concatenate(([5], shape), axis=0), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - def _make_dense_to_sparse_fn(self, is_scalar): - - def dense_to_sparse_scalar(tensor): - indices = [[]] - values = array_ops.expand_dims(tensor, 0) - shape = [] - return sparse_tensor.SparseTensorValue(indices, values, shape) - - def dense_to_sparse_non_scalar(tensor): - indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool)) - values = array_ops.gather_nd(tensor, indices) - shape = array_ops.shape(tensor, out_type=dtypes.int64) - return sparse_tensor.SparseTensorValue(indices, values, shape) - - if is_scalar: - return dense_to_sparse_scalar - return dense_to_sparse_non_scalar - - def _structuredSparseDataset(self, structure, shape, dtype): - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - if structure is None: - return dataset_ops.Dataset.from_tensors( - dense_to_sparse(array_ops.zeros(shape, dtype=dtype))) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredSparseDataset(substructure, shape, dtype) - for substructure in structure - ])) - - def _structuredSparseElement(self, structure, shape, dtype): - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - if structure is None: - return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - else: - return tuple([ - self._structuredSparseElement(substructure, shape, dtype) - for substructure in structure - ]) - - @parameterized.named_parameters( - ("1", None, np.int32([]), dtypes.bool), - ("2", None, np.int32([]), dtypes.int32), - ("3", None, np.int32([]), dtypes.float32), - ("4", None, np.int32([]), dtypes.string), - ("5", None, np.int32([2]), dtypes.int32), - ("6", None, np.int32([2, 2]), dtypes.int32), - ("7", (None, None, None), np.int32([]), dtypes.int32), - ("8", (None, (None, None)), np.int32([]), dtypes.int32), - ) - def testWindowDatasetBatchSparse(self, structure, shape, dtype): - """Tests batching of sparse tensor windows. - - Args: - structure: the input structure - shape: the input shape - dtype: the input data type - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.batch_window(args[0]) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) - for arg in args - ]) - - dataset = self._structuredSparseDataset( - structure, shape, dtype).repeat(5).apply( - grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredSparseElement(structure, - np.concatenate(([5], shape), axis=0), - dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([])), - ("2", np.int32([1])), - ("3", np.int32([1, 2, 3])), - ) - def testWindowDatasetBatchSparseDynamicShape(self, shape): - """Tests batching of dynamically shaped sparse tensor windows. - - Args: - shape: the input shape - """ - - shape_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map( - self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test - grouping.window_dataset(5)).apply( - grouping._map_x_dataset(batching.batch_window)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shape_t: shape}) - expected = sess.run( - self._structuredSparseElement(None, - np.concatenate(([5], shape), axis=0), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - def _structuredRaggedDataset(self, structure, shapes, dtype): - - if structure is None: - return dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtype)) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredRaggedDataset(substructure, shapes, dtype) - for substructure in structure - ])) - - @parameterized.named_parameters( - ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), - ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), - ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), - ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), - ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("8", (None, - (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), - ) - def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, - padded_shape): - """Tests padded batching of dense tensor windows. - - Args: - structure: the input structure - shapes: the input shapes - dtype: the input data type - padded_shape: the shape to pad the output to - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.padded_batch_window(args[0], padded_shape) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( - arg, padded_shape) for arg in args - ]) - - dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - expected = sess.run( - self._structuredElement( - structure, - np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([[1], [2], [3]]), [-1]), - ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), - ) - def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): - """Tests padded batching of dynamically shaped dense tensor windows. - - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - shapes_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shapes_t: shapes}) - expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - expected = sess.run( - self._structuredElement( - None, np.concatenate((np.int32([len(shapes)]), expected_shape)), - dtypes.int32)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int32([[1]]), np.int32([0])), - ("2", np.int32([[10], [20]]), np.int32([15])), - ) - def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): - """Tests invalid padded batching of dense tensor windows. - - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( - grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - def _structuredRaggedSparseDataset(self, structure, shapes, dtype): - - def map_fn(shape): - dense_to_sparse = self._make_dense_to_sparse_fn(False) - return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - - if structure is None: - return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn) - else: - return dataset_ops.Dataset.zip( - tuple([ - self._structuredRaggedSparseDataset(substructure, shapes, dtype) - for substructure in structure - ])) - - def _structuredRaggedSparseElement(self, structure, shapes, dtype, - padded_shape): - if structure is None: - dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) - values = [] - for shape in shapes: - dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test - sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) - padded_sparse = sparse_tensor.SparseTensor(sparse.indices, - sparse.values, dense_shape) - reshaped_sparse = sparse_ops.sparse_reshape( - padded_sparse, - array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0)) - values.append(reshaped_sparse) - return sparse_ops.sparse_concat(0, values) - else: - return tuple([ - self._structuredRaggedSparseElement(substructure, shapes, dtype, - padded_shape) - for substructure in structure - ]) - - @parameterized.named_parameters( - ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), - ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), - ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), - ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), - ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("8", (None, - (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), - ) - def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, - padded_shape): - """Tests padded batching of sparse tensor windows. - - Args: - structure: the input structure - shapes: the input shapes - dtype: the input data type - padded_shape: the shape to pad the output to - """ - - def fn(*args): - if len(args) == 1 and not isinstance(args[0], tuple): - return batching.padded_batch_window(args[0], padded_shape) - - return tuple([ - fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( - arg, padded_shape) for arg in args - ]) - - dataset = self._structuredRaggedSparseDataset( - structure, shapes, dtype).apply(grouping.window_dataset( - len(shapes))).apply(grouping._map_x_dataset(fn)) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - expected = sess.run( - self._structuredRaggedSparseElement(structure, shapes, dtype, - padded_shape)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int64([[1], [2], [3]]), [-1]), - ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), - ) - def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, - padded_shape): - """Tests padded batching of dynamically shaped sparse tensor windows. - - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - shapes_t = array_ops.placeholder(dtypes.int32) - dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( - self._make_dense_to_sparse_fn(False) - ).apply(grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - with self.cached_session() as sess: - sess.run(init_op, {shapes_t: shapes}) - expected = sess.run( - self._structuredRaggedSparseElement(None, shapes, dtypes.int32, - padded_shape)) - actual = sess.run(get_next) - self._assertEqual(expected, actual) - - @parameterized.named_parameters( - ("1", np.int64([[1]]), [0]), - ("2", np.int64([[10], [20]]), [15]), - ) - def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): - """Tests invalid padded batching of sparse tensor windows. - - Args: - shapes: the input shapes - padded_shape: the shape to pad the output to - """ - - dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( - lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( - self._make_dense_to_sparse_fn(False) - ).apply(grouping.window_dataset(len(shapes))).apply( - grouping._map_x_dataset( - lambda x: batching.padded_batch_window(x, padded_shape))) - get_next = dataset.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 4b45cc7e36d14e99d1132b919dfc175a1217f8b9..34dc2379d0cb38f8f6962fa42efe21b793bc8d65 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -16,10 +16,7 @@ py_library( srcs = ["counter.py"], srcs_version = "PY2AND3", deps = [ - ":scan_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:counter", ], ) @@ -28,12 +25,7 @@ py_library( srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ - ":grouping", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - "//third_party/py/numpy", + "//tensorflow/python/data/experimental/ops:get_single_element", ], ) @@ -44,10 +36,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/experimental/ops:iterator_ops", ], ) @@ -58,15 +47,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:random_ops", ], ) @@ -78,18 +59,19 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", - ":gen_dataset_ops", ":interleave_ops", ":parsing_ops", ":shuffle_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/experimental/ops:readers", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:convert", @@ -105,7 +87,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:shuffle_ops", ], ) @@ -124,6 +106,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", @@ -137,8 +120,7 @@ py_library( srcs = ["enumerate_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/experimental/ops:enumerate_ops", ], ) @@ -147,11 +129,7 @@ py_library( srcs = ["error_ops.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:error_ops", ], ) @@ -160,16 +138,7 @@ py_library( srcs = ["grouping.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:grouping", ], ) @@ -178,32 +147,7 @@ py_library( srcs = ["interleave_ops.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - ":random_ops", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - -py_library( - name = "optimization", - srcs = ["optimization.py"], - srcs_version = "PY2AND3", - deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:interleave_ops", ], ) @@ -212,25 +156,7 @@ py_library( srcs = ["parsing_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -py_library( - name = "map_defun", - srcs = ["map_defun.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/experimental/ops:parsing_ops", ], ) @@ -239,18 +165,7 @@ py_library( srcs = ["resampling.py"], srcs_version = "PY2AND3", deps = [ - ":batching", - ":interleave_ops", - ":scan_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:logging_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "//tensorflow/python/data/experimental/ops:resampling", ], ) @@ -259,12 +174,7 @@ py_library( srcs = ["scan_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:scan_ops", ], ) @@ -283,33 +193,12 @@ py_library( ], ) -py_library( - name = "stats_ops", - srcs = ["stats_ops.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - py_library( name = "threadpool", srcs = ["threadpool.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - "//tensorflow/python/eager:context", + "//tensorflow/python/data/experimental/ops:threadpool", ], ) @@ -320,12 +209,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:unique", ], ) @@ -336,56 +220,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python/data/ops:dataset_ops", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_dataset_ops", - out = "gen_dataset_ops.py", - deps = [ - "//tensorflow/contrib/data:dataset_ops_op_lib", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", - ], -) - -tf_kernel_library( - name = "dataset_ops_kernels", - deps = [ - "//tensorflow/contrib/data/kernels:dataset_kernels", - "//tensorflow/core:framework", - ], - alwayslink = 1, -) - -tf_custom_op_py_library( - name = "contrib_op_loader", - srcs = ["contrib_op_loader.py"], - dso = ["//tensorflow/contrib/data:_dataset_ops.so"], - kernels = [ - ":dataset_ops_kernels", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", - "//tensorflow/contrib/data:dataset_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_dataset_ops", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:platform", - ], -) - -py_library( - name = "indexed_dataset_ops", - srcs = ["indexed_dataset_ops.py"], - deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:writers", ], ) @@ -393,11 +228,7 @@ py_library( name = "prefetching_ops", srcs = ["prefetching_ops.py"], deps = [ - ":contrib_op_loader", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/experimental/ops:prefetching_ops", ], ) @@ -410,17 +241,14 @@ py_library( ":error_ops", ":get_single_element", ":grouping", - ":indexed_dataset_ops", ":interleave_ops", - ":map_defun", - ":optimization", ":prefetching_ops", + ":random_ops", ":readers", ":resampling", ":scan_ops", ":shuffle_ops", ":sliding", - ":stats_ops", ":threadpool", ":unique", ":writers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 367c159dc5db688b652f2e88a92e44186d7c8bfd..8c60459ca81cd7a7e08d90339011c54275ea9c0b 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,134 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.ops import get_single_element -from tensorflow.contrib.data.python.ops import grouping from tensorflow.contrib.framework import with_shape -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert +from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation -def batch_window(dataset): - """Batches a window of tensors. - - Args: - dataset: the input dataset. - - Returns: - A `Tensor` representing the batch of the entire input dataset. - """ - if isinstance(dataset.output_classes, tuple): - raise TypeError("Input dataset expected to have a single component") - if dataset.output_classes is ops.Tensor: - return _batch_dense_window(dataset) - elif dataset.output_classes is sparse_tensor.SparseTensor: - return _batch_sparse_window(dataset) - else: - raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) - - -def _batch_dense_window(dataset): - """Batches a window of dense tensors.""" - - def key_fn(_): - return np.int64(0) - - def shape_init_fn(_): - return array_ops.shape(first_element) - - def shape_reduce_fn(state, value): - check_ops.assert_equal(state, array_ops.shape(value)) - return state - - def finalize_fn(state): - return state - - if dataset.output_shapes.is_fully_defined(): - shape = dataset.output_shapes - else: - first_element = get_single_element.get_single_element(dataset.take(1)) - shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, - finalize_fn) - shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) - - def batch_init_fn(_): - batch_shape = array_ops.concat([[0], shape], 0) - return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) - - def batch_reduce_fn(state, value): - return array_ops.concat([state, [value]], 0) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) - - -def _batch_sparse_window(dataset): - """Batches a window of sparse tensors.""" - - def key_fn(_): - return np.int64(0) - - def shape_init_fn(_): - return first_element.dense_shape - - def shape_reduce_fn(state, value): - check_ops.assert_equal(state, value.dense_shape) - return state - - def finalize_fn(state): - return state - - if dataset.output_shapes.is_fully_defined(): - shape = dataset.output_shapes - else: - first_element = get_single_element.get_single_element(dataset.take(1)) - shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, - finalize_fn) - shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) - - def batch_init_fn(_): - indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) - return sparse_tensor.SparseTensor( - indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), - values=constant_op.constant([], shape=[0], dtype=dataset.output_types), - dense_shape=array_ops.concat( - [np.array([0], dtype=np.int64), - math_ops.cast(shape, dtypes.int64)], 0)) - - def batch_reduce_fn(state, value): - return sparse_ops.sparse_concat(0, [state, value]) - - def reshape_fn(value): - return sparse_ops.sparse_reshape( - value, - array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.map(reshape_fn).apply( - grouping.group_by_reducer(key_fn, batch_reducer))) - - +@deprecation.deprecated( + None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.") def dense_to_sparse_batch(batch_size, row_shape): """A transformation that batches ragged elements into `tf.SparseTensor`s. @@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) - - return _apply_fn - - -def padded_batch_window(dataset, padded_shape, padding_value=None): - """Batches a window of tensors with padding. - - Args: - dataset: the input dataset. - padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like - object representing the shape to which the input elements should be padded - prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a - `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the - maximum size of that dimension in each batch. - padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the - padding value to use. Defaults are `0` for numeric types and the empty - string for string types. If `dataset` contains `tf.SparseTensor`, this - value is ignored. - - Returns: - A `Tensor` representing the batch of the entire input dataset. - - Raises: - ValueError: if invalid arguments are provided. - """ - if not issubclass(dataset.output_classes, - (ops.Tensor, sparse_tensor.SparseTensor)): - raise TypeError("Input dataset expected to have a single tensor component") - if issubclass(dataset.output_classes, (ops.Tensor)): - return _padded_batch_dense_window(dataset, padded_shape, padding_value) - elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): - if padding_value is not None: - raise ValueError("Padding value not allowed for sparse tensors") - return _padded_batch_sparse_window(dataset, padded_shape) - else: - raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) - - -def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): - """Batches a window of dense tensors with padding.""" - - padded_shape = math_ops.cast( - convert.partial_shape_to_tensor(padded_shape), dtypes.int32) - - def key_fn(_): - return np.int64(0) - - def max_init_fn(_): - return padded_shape - - def max_reduce_fn(state, value): - """Computes the maximum shape to pad to.""" - condition = math_ops.reduce_all( - math_ops.logical_or( - math_ops.less_equal(array_ops.shape(value), padded_shape), - math_ops.equal(padded_shape, -1))) - assert_op = control_flow_ops.Assert(condition, [ - "Actual shape greater than padded shape: ", - array_ops.shape(value), padded_shape - ]) - with ops.control_dependencies([assert_op]): - return math_ops.maximum(state, array_ops.shape(value)) - - def finalize_fn(state): - return state - - # Compute the padded shape. - max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) - padded_shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) - - if padding_value is None: - if dataset.output_types == dtypes.string: - padding_value = "" - elif dataset.output_types == dtypes.bool: - padding_value = False - elif dataset.output_types == dtypes.variant: - raise TypeError("Unable to create padding for field of type 'variant'") - else: - padding_value = 0 - - def batch_init_fn(_): - batch_shape = array_ops.concat( - [np.array([0], dtype=np.int32), padded_shape], 0) - return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) - - def batch_reduce_fn(state, value): - return array_ops.concat([state, [value]], 0) - - def pad_fn(value): - shape = array_ops.shape(value) - left = array_ops.zeros_like(shape) - right = padded_shape - shape - return array_ops.pad( - value, array_ops.stack([left, right], 1), constant_values=padding_value) - - batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.map(pad_fn).apply( - grouping.group_by_reducer(key_fn, batch_reducer))) - - -def _padded_batch_sparse_window(dataset, padded_shape): - """Batches a window of sparse tensors with padding.""" - - def key_fn(_): - return np.int64(0) - - def max_init_fn(_): - return convert.partial_shape_to_tensor(padded_shape) - - def max_reduce_fn(state, value): - """Computes the maximum shape to pad to.""" - condition = math_ops.reduce_all( - math_ops.logical_or( - math_ops.less_equal(value.dense_shape, padded_shape), - math_ops.equal(padded_shape, -1))) - assert_op = control_flow_ops.Assert(condition, [ - "Actual shape greater than padded shape: ", value.dense_shape, - padded_shape - ]) - with ops.control_dependencies([assert_op]): - return math_ops.maximum(state, value.dense_shape) - - def finalize_fn(state): - return state - - # Compute the padded shape. - max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) - padded_shape = get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) - - def batch_init_fn(_): - indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], - 0) - return sparse_tensor.SparseTensor( - indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), - values=constant_op.constant([], shape=[0], dtype=dataset.output_types), - dense_shape=array_ops.concat( - [np.array([0], dtype=np.int64), padded_shape], 0)) - - def batch_reduce_fn(state, value): - padded_value = sparse_tensor.SparseTensor( - indices=value.indices, values=value.values, dense_shape=padded_shape) - reshaped_value = sparse_ops.sparse_reshape( - padded_value, - array_ops.concat( - [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) - return sparse_ops.sparse_concat(0, [state, reshaped_value]) - - reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) - return get_single_element.get_single_element( - dataset.apply(grouping.group_by_reducer(key_fn, reducer))) - - -class _UnbatchDataset(dataset_ops.Dataset): - """A dataset that splits the elements of its input into multiple elements.""" - - def __init__(self, input_dataset): - """See `unbatch()` for more details.""" - super(_UnbatchDataset, self).__init__() - flat_shapes = nest.flatten(input_dataset.output_shapes) - if any(s.ndims == 0 for s in flat_shapes): - raise ValueError("Cannot unbatch an input with scalar components.") - known_batch_dim = tensor_shape.Dimension(None) - for s in flat_shapes: - try: - known_batch_dim = known_batch_dim.merge_with(s[0]) - except ValueError: - raise ValueError("Cannot unbatch an input whose components have " - "different batch sizes.") - self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_dataset_ops.unbatch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return nest.map_structure(lambda s: s[1:], - self._input_dataset.output_shapes) - - @property - def output_types(self): - return self._input_dataset.output_types + return batching.dense_to_sparse_batch(batch_size, row_shape) +@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.") def unbatch(): """Splits elements of a dataset into multiple elements on the batch dimension. @@ -403,39 +92,7 @@ def unbatch(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - if not sparse.any_sparse(dataset.output_classes): - return _UnbatchDataset(dataset) - - # NOTE(mrry): We must ensure that any SparseTensors in `dataset` - # are normalized to the rank-1 dense representation, so that the - # sparse-oblivious unbatching logic will slice them - # appropriately. This leads to a somewhat inefficient re-encoding step - # for all SparseTensor components. - # TODO(mrry): Consider optimizing this in future - # if it turns out to be a bottleneck. - def normalize(arg, *rest): - if rest: - return sparse.serialize_many_sparse_tensors((arg,) + rest) - else: - return sparse.serialize_many_sparse_tensors(arg) - - normalized_dataset = dataset.map(normalize) - - # NOTE(mrry): Our `map()` has lost information about the sparseness - # of any SparseTensor components, so re-apply the structure of the - # original dataset. - restructured_dataset = _RestructuredDataset( - normalized_dataset, - dataset.output_types, - dataset.output_shapes, - dataset.output_classes, - allow_unsafe_cast=True) - return _UnbatchDataset(restructured_dataset) - - return _apply_fn + return batching.unbatch() @deprecation.deprecated( @@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size, return _apply_fn -class _DenseToSparseBatchDataset(dataset_ops.Dataset): - """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" - - def __init__(self, input_dataset, batch_size, row_shape): - """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(_DenseToSparseBatchDataset, self).__init__() - if not isinstance(input_dataset.output_types, dtypes.DType): - raise TypeError("DenseToSparseDataset requires an input whose elements " - "have a single component, whereas the input has %r." % - input_dataset.output_types) - self._input_dataset = input_dataset - self._batch_size = batch_size - self._row_shape = row_shape - - def _as_variant_tensor(self): - return gen_dataset_ops.dense_to_sparse_batch_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._batch_size, - row_shape=convert.partial_shape_to_tensor(self._row_shape), - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return sparse_tensor.SparseTensor - - @property - def output_shapes(self): - return tensor_shape.vector(None).concatenate(self._row_shape) - - @property - def output_types(self): - return self._input_dataset.output_types - - -class _RestructuredDataset(dataset_ops.Dataset): - """An internal helper for changing the structure and shape of a dataset.""" - - def __init__(self, - dataset, - output_types, - output_shapes=None, - output_classes=None, - allow_unsafe_cast=False): - """Creates a new dataset with the given output types and shapes. - - The given `dataset` must have a structure that is convertible: - * `dataset.output_types` must be the same as `output_types` module nesting. - * Each shape in `dataset.output_shapes` must be compatible with each shape - in `output_shapes` (if given). - - Note: This helper permits "unsafe casts" for shapes, equivalent to using - `tf.Tensor.set_shape()` where domain-specific knowledge is available. - - Args: - dataset: A `Dataset` object. - output_types: A nested structure of `tf.DType` objects. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. - If omitted, the shapes will be inherited from `dataset`. - output_classes: (Optional.) A nested structure of class types. - If omitted, the class types will be inherited from `dataset`. - allow_unsafe_cast: (Optional.) If `True`, the caller may switch the - reported output types and shapes of the restructured dataset, e.g. to - switch a sparse tensor represented as `tf.variant` to its user-visible - type and shape. - - Raises: - ValueError: If either `output_types` or `output_shapes` is not compatible - with the structure of `dataset`. - """ - super(_RestructuredDataset, self).__init__() - self._input_dataset = dataset - - if not allow_unsafe_cast: - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(dataset.output_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have " - "output types %r" % (dataset.output_types, output_types)) - - self._output_types = output_types - - if output_shapes is None: - # Inherit shapes from the original `dataset`. - self._output_shapes = nest.pack_sequence_as(output_types, - nest.flatten( - dataset.output_shapes)) - else: - if not allow_unsafe_cast: - # Validate that the shapes are compatible. - nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(dataset.output_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" % (dataset.output_shapes, - output_shapes)) - self._output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - if output_classes is None: - # Inherit class types from the original `dataset`. - self._output_classes = nest.pack_sequence_as(output_types, - nest.flatten( - dataset.output_classes)) - else: - self._output_classes = output_classes - - def _as_variant_tensor(self): - return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - - @property - def output_classes(self): - return self._output_classes - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - +# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()` +# function is available in the core. def assert_element_shape(expected_shapes): """Assert the shape of this `Dataset`. @@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes): def _apply_fn(dataset): output_shapes = _merge_output_shapes(dataset.output_shapes, expected_shapes) - return _RestructuredDataset( + # pylint: disable=protected-access + return batching._RestructuredDataset( dataset.map(_check_shape), dataset.output_types, output_shapes=output_shapes, @@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes): return _apply_fn -class _MapAndBatchDataset(dataset_ops.MapDataset): - """A `Dataset` that maps a function over a batch of elements.""" - - def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, - drop_remainder): - """See `Dataset.map()` for details.""" - super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size_t = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_calls_t = ops.convert_to_tensor( - num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") - self._drop_remainder_t = ops.convert_to_tensor( - drop_remainder, dtype=dtypes.bool, name="drop_remainder") - - self._batch_size = batch_size - self._drop_remainder = drop_remainder - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset_v2( - input_resource, - self._map_func.captured_inputs, - f=self._map_func, - batch_size=self._batch_size_t, - num_parallel_calls=self._num_parallel_calls_t, - drop_remainder=self._drop_remainder_t, - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_shapes(self): - dim = self._batch_size if self._drop_remainder else None - return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(dim).concatenate(s) - for s in nest.flatten(self._output_shapes) - ]) - - @property - def output_types(self): - return self._output_types - - +@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch(...)`.") def map_and_batch(map_func, batch_size, num_parallel_batches=None, @@ -779,17 +268,5 @@ def map_and_batch(map_func, ValueError: If both `num_parallel_batches` and `num_parallel_calls` are specified. """ - - if num_parallel_batches is None and num_parallel_calls is None: - num_parallel_calls = batch_size - elif num_parallel_batches is not None and num_parallel_calls is None: - num_parallel_calls = batch_size * num_parallel_batches - elif num_parallel_batches is not None and num_parallel_calls is not None: - raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " - "arguments are mutually exclusive.") - - def _apply_fn(dataset): - return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_calls, drop_remainder) - - return _apply_fn + return batching.map_and_batch(map_func, batch_size, num_parallel_batches, + drop_remainder, num_parallel_calls) diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 6ef65f9624601286691505a795a86dd6226eead1..4ff5bf3e39dc2c9313b7d47d1ef965ebb22afc06 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -17,13 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import scan_ops - -from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.experimental.ops import counter from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.") def Counter(start=0, step=1, dtype=dtypes.int64): """Creates a `Dataset` that counts from `start` in steps of size `step`. @@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64): Returns: A `Dataset` of scalar `dtype` elements. """ - with ops.name_scope("counter"): - start = ops.convert_to_tensor(start, dtype=dtype, name="start") - step = ops.convert_to_tensor(step, dtype=dtype, name="step") - return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( - scan_ops.scan(start, lambda state, _: (state + step, state))) + return counter.Counter(start, step, dtype) diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index 490281e0d2da7a454a2f63f95753c7c436b87a76..a21da4d3eca508f2af9bac49d57fb0c4b08f3be0 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes +from tensorflow.python.data.experimental.ops import enumerate_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.enumerate_dataset(...)`.") def enumerate_dataset(start=0): """A transformation that enumerate the elements of a dataset. @@ -49,10 +50,4 @@ def enumerate_dataset(start=0): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max - return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value), - dataset)) - - return _apply_fn + return enumerate_ops.enumerate_dataset(start) diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index b4a7521e0875089c39ac7aa8b7b49e44feb2b4ad..0559a2e09cce43cf16e88dbe20dba2c46db4c979 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,11 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.experimental.ops import error_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.") def ignore_errors(): """Creates a `Dataset` from another `Dataset` and silently ignores any errors. @@ -44,34 +44,4 @@ def ignore_errors(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _IgnoreErrorsDataset(dataset) - - return _apply_fn - - -class _IgnoreErrorsDataset(dataset_ops.Dataset): - """A `Dataset` that silently ignores errors when computing its input.""" - - def __init__(self, input_dataset): - """See `Dataset.ignore_errors()` for details.""" - super(_IgnoreErrorsDataset, self).__init__() - self._input_dataset = input_dataset - - def _as_variant_tensor(self): - return gen_dataset_ops.ignore_errors_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types + return error_ops.ignore_errors() diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a6713b017afa315edec9389d0a6c1c7135e6aeb9..58ad9eea903c42981b8fd083ed1c39421c58189f 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,13 +19,13 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.get_single_element(...)`.") def get_single_element(dataset): """Returns the single element in `dataset` as a nested structure of tensors. @@ -61,18 +61,10 @@ def get_single_element(dataset): InvalidArgumentError (at runtime): if `dataset` does not contain exactly one element. """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - - nested_ret = nest.pack_sequence_as( - dataset.output_types, gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(dataset))) - return sparse.deserialize_sparse_tensors( - nested_ret, dataset.output_types, dataset.output_shapes, - dataset.output_classes) + return experimental_get_single_element.get_single_element(dataset) +@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.") def reduce_dataset(dataset, reducer): """Returns the result of reducing the `dataset` using `reducer`. @@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - # The sentinel dataset is used in case the reduced dataset is empty. - sentinel_dataset = dataset_ops.Dataset.from_tensors( - reducer.finalize_func(reducer.init_func(np.int64(0)))) - reduced_dataset = dataset.apply( - grouping.group_by_reducer(lambda x: np.int64(0), reducer)) - - return get_single_element( - reduced_dataset.concatenate(sentinel_dataset).take(1)) + return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 099e10db921b78fc9fa3bcf73979ae6c33bc1972..a99dc2f29ae4c9d47c21afd83f49bf4eb89eca18 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,20 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.data.experimental.ops import grouping +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.group_by_reducer(...)`.") def group_by_reducer(key_func, reducer): """A transformation that groups elements and performs a reduction. @@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _GroupByReducerDataset(dataset, key_func, reducer) - - return _apply_fn + return grouping.group_by_reducer(key_func, reducer) +@deprecation.deprecated(None, + "Use `tf.data.experimental.group_by_window(...)`.") def group_by_window(key_func, reduce_func, window_size=None, @@ -98,27 +88,12 @@ def group_by_window(key_func, ValueError: if neither or both of {`window_size`, `window_size_func`} are passed. """ - if (window_size is not None and window_size_func or - not (window_size is not None or window_size_func)): - raise ValueError("Must pass either window_size or window_size_func.") - - if window_size is not None: - - def constant_window_func(unused_key): - return ops.convert_to_tensor(window_size, dtype=dtypes.int64) - - window_size_func = constant_window_func - - assert window_size_func is not None - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) - - return _apply_fn + return grouping.group_by_window(key_func, reduce_func, window_size, + window_size_func) +@deprecation.deprecated( + None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.") def bucket_by_sequence_length(element_length_func, bucket_boundaries, bucket_batch_sizes, @@ -163,336 +138,12 @@ def bucket_by_sequence_length(element_length_func, Raises: ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. """ - with ops.name_scope("bucket_by_seq_length"): - if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): - raise ValueError( - "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") - - batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) - - def element_to_bucket_id(*args): - """Return int64 id of the length bucket for this element.""" - seq_length = element_length_func(*args) - - boundaries = list(bucket_boundaries) - buckets_min = [np.iinfo(np.int32).min] + boundaries - buckets_max = boundaries + [np.iinfo(np.int32).max] - conditions_c = math_ops.logical_and( - math_ops.less_equal(buckets_min, seq_length), - math_ops.less(seq_length, buckets_max)) - bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) - - return bucket_id - - def window_size_fn(bucket_id): - # The window size is set to the batch size for this bucket - window_size = batch_sizes[bucket_id] - return window_size - - def make_padded_shapes(shapes, none_filler=None): - padded = [] - for shape in nest.flatten(shapes): - shape = tensor_shape.TensorShape(shape) - shape = [ - none_filler if d.value is None else d - for d in shape - ] - padded.append(shape) - return nest.pack_sequence_as(shapes, padded) - - def batching_fn(bucket_id, grouped_dataset): - """Batch elements in dataset.""" - batch_size = window_size_fn(bucket_id) - if no_padding: - return grouped_dataset.batch(batch_size) - none_filler = None - if pad_to_bucket_boundary: - err_msg = ("When pad_to_bucket_boundary=True, elements must have " - "length < max(bucket_boundaries).") - check = check_ops.assert_less( - bucket_id, - constant_op.constant(len(bucket_batch_sizes) - 1, - dtype=dtypes.int64), - message=err_msg) - with ops.control_dependencies([check]): - boundaries = constant_op.constant(bucket_boundaries, - dtype=dtypes.int64) - bucket_boundary = boundaries[bucket_id] - none_filler = bucket_boundary - 1 - shapes = make_padded_shapes( - padded_shapes or grouped_dataset.output_shapes, - none_filler=none_filler) - return grouped_dataset.padded_batch(batch_size, shapes, padding_values) - - def _apply_fn(dataset): - return dataset.apply( - group_by_window(element_to_bucket_id, batching_fn, - window_size_func=window_size_fn)) - - return _apply_fn - - -def _map_x_dataset(map_func): - """A transformation that maps `map_func` across its input. - - This transformation is similar to `tf.data.Dataset.map`, but in addition to - supporting dense and sparse tensor inputs, it also supports dataset inputs. - - Args: - map_func: A function mapping a nested structure of tensors and/or datasets - (having shapes and types defined by `self.output_shapes` and - `self.output_types`) to another nested structure of tensors and/or - datasets. - - Returns: - Dataset: A `Dataset`. - """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return _MapXDataset(dataset, map_func) - - return _apply_fn - - -def window_dataset(window_size): - """A transformation that creates window datasets from the input dataset. - - The resulting datasets will contain `window_size` elements (or - `N % window_size` for the last dataset if `window_size` does not divide the - number of input elements `N` evenly). - - Args: - window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of the input dataset to combine into a window. - - Returns: - Dataset: A `Dataset`. - """ - - def _apply_fn(dataset): - return _WindowDataset(dataset, window_size) - - return _apply_fn - - -class _GroupByReducerDataset(dataset_ops.Dataset): - """A `Dataset` that groups its input and performs a reduction.""" - - def __init__(self, input_dataset, key_func, reducer): - """See `group_by_reducer()` for details.""" - super(_GroupByReducerDataset, self).__init__() - - self._input_dataset = input_dataset - - self._make_key_func(key_func, input_dataset) - self._make_init_func(reducer.init_func) - self._make_reduce_func(reducer.reduce_func, input_dataset) - self._make_finalize_func(reducer.finalize_func) - - def _make_key_func(self, key_func, input_dataset): - """Make wrapping Defun for key_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - key_func, "tf.contrib.data.group_by_reducer()", input_dataset) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. " - "Got type=%s and shape=%s" - % (wrapped_func.output_types, wrapped_func.output_shapes)) - self._key_func = wrapped_func.function + return grouping.bucket_by_sequence_length( + element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes, + padding_values, pad_to_bucket_boundary, no_padding) - def _make_init_func(self, init_func): - """Make wrapping Defun for init_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - init_func, "tf.contrib.data.group_by_reducer()", - input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), - input_types=dtypes.int64) - self._init_func = wrapped_func.function - self._state_classes = wrapped_func.output_classes - self._state_shapes = wrapped_func.output_shapes - self._state_types = wrapped_func.output_types - def _make_reduce_func(self, reduce_func, input_dataset): - """Make wrapping Defun for reduce_func.""" - - # Iteratively rerun the reduce function until reaching a fixed point on - # `self._state_shapes`. - need_to_rerun = True - while need_to_rerun: - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - reduce_func, "tf.contrib.data.group_by_reducer()", - input_classes=(self._state_classes, input_dataset.output_classes), - input_shapes=(self._state_shapes, input_dataset.output_shapes), - input_types=(self._state_types, input_dataset.output_types), - add_to_graph=False) - - # Extract and validate class information from the returned values. - for new_state_class, state_class in zip( - nest.flatten(wrapped_func.output_classes), - nest.flatten(self._state_classes)): - if not issubclass(new_state_class, state_class): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, wrapped_func.output_classes)) - - # Extract and validate type information from the returned values. - for new_state_type, state_type in zip( - nest.flatten(wrapped_func.output_types), - nest.flatten(self._state_types)): - if new_state_type != state_type: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, wrapped_func.output_types)) - - # Extract shape information from the returned values. - flat_state_shapes = nest.flatten(self._state_shapes) - flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) - weakened_state_shapes = [ - original.most_specific_compatible_shape(new) - for original, new in zip(flat_state_shapes, flat_new_state_shapes) - ] - - need_to_rerun = False - for original_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if original_shape.ndims is not None and ( - weakened_shape.ndims is None or - original_shape.as_list() != weakened_shape.as_list()): - need_to_rerun = True - break - - if need_to_rerun: - self._state_shapes = nest.pack_sequence_as(self._state_shapes, - weakened_state_shapes) - - self._reduce_func = wrapped_func.function - self._reduce_func.add_to_graph(ops.get_default_graph()) - - def _make_finalize_func(self, finalize_func): - """Make wrapping Defun for finalize_func.""" - wrapped_func = dataset_ops.StructuredFunctionWrapper( - finalize_func, "tf.contrib.data.group_by_reducer()", - input_classes=self._state_classes, input_shapes=self._state_shapes, - input_types=self._state_types) - self._finalize_func = wrapped_func.function - self._output_classes = wrapped_func.output_classes - self._output_shapes = wrapped_func.output_shapes - self._output_types = wrapped_func.output_types - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.group_by_reducer_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.captured_inputs, - self._init_func.captured_inputs, - self._reduce_func.captured_inputs, - self._finalize_func.captured_inputs, - key_func=self._key_func, - init_func=self._init_func, - reduce_func=self._reduce_func, - finalize_func=self._finalize_func, - **dataset_ops.flat_structure(self)) - - -class _GroupByWindowDataset(dataset_ops.Dataset): - """A `Dataset` that groups its input and performs a windowed reduction.""" - - def __init__(self, input_dataset, key_func, reduce_func, window_size_func): - """See `group_by_window()` for details.""" - super(_GroupByWindowDataset, self).__init__() - - self._input_dataset = input_dataset - - self._make_key_func(key_func, input_dataset) - self._make_reduce_func(reduce_func, input_dataset) - self._make_window_size_func(window_size_func) - - def _make_window_size_func(self, window_size_func): - """Make wrapping Defun for window_size_func.""" - def window_size_func_wrapper(key): - return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) - wrapped_func = dataset_ops.StructuredFunctionWrapper( - window_size_func_wrapper, "tf.contrib.data.group_by_window()", - input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), - input_types=dtypes.int64) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`window_size_func` must return a single tf.int64 scalar tensor.") - self._window_size_func = wrapped_func.function - - def _make_key_func(self, key_func, input_dataset): - """Make wrapping Defun for key_func.""" - def key_func_wrapper(*args): - return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) - wrapped_func = dataset_ops.StructuredFunctionWrapper( - key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) - if not ( - wrapped_func.output_types == dtypes.int64 and - wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): - raise ValueError( - "`key_func` must return a single tf.int64 scalar tensor.") - self._key_func = wrapped_func.function - - def _make_reduce_func(self, reduce_func, input_dataset): - """Make wrapping Defun for reduce_func.""" - nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access - wrapped_func = dataset_ops.StructuredFunctionWrapper( - reduce_func, "tf.contrib.data.reduce_by_window()", - input_classes=(ops.Tensor, nested_dataset), - input_shapes=(tensor_shape.scalar(), nested_dataset), - input_types=(dtypes.int64, nested_dataset), - experimental_nested_dataset_support=True) - if not isinstance( - wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = wrapped_func.output_classes.output_classes - self._output_types = wrapped_func.output_types.output_types - self._output_shapes = wrapped_func.output_shapes.output_shapes - self._reduce_func = wrapped_func.function - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.group_by_window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._key_func.captured_inputs, - self._reduce_func.captured_inputs, - self._window_size_func.captured_inputs, - key_func=self._key_func, - reduce_func=self._reduce_func, - window_size_func=self._window_size_func, - **dataset_ops.flat_structure(self)) - - -class Reducer(object): +class Reducer(grouping.Reducer): """A reducer is used for reducing a set of elements. A reducer is represented as a tuple of the three functions: @@ -501,101 +152,6 @@ class Reducer(object): 3) finalization function: state => result """ + @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.") def __init__(self, init_func, reduce_func, finalize_func): - self._init_func = init_func - self._reduce_func = reduce_func - self._finalize_func = finalize_func - - @property - def init_func(self): - return self._init_func - - @property - def reduce_func(self): - return self._reduce_func - - @property - def finalize_func(self): - return self._finalize_func - - -class _MapXDataset(dataset_ops.Dataset): - """A `Dataset` that maps a function over elements in its input.""" - - def __init__(self, input_dataset, map_func): - """See `map_x_dataset()` for details.""" - super(_MapXDataset, self).__init__() - self._input_dataset = input_dataset - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - map_func, - "tf.contrib.data.map_x_dataset()", - input_dataset, - experimental_nested_dataset_support=True) - self._output_classes = wrapped_func.output_classes - self._output_shapes = wrapped_func.output_shapes - self._output_types = wrapped_func.output_types - self._map_func = wrapped_func.function - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return gen_dataset_ops.map_dataset( - input_t, - self._map_func.captured_inputs, - f=self._map_func, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - -class _WindowDataset(dataset_ops.Dataset): - """A dataset that creates window datasets from the input elements.""" - - def __init__(self, input_dataset, window_size): - """See `window_dataset()` for more details.""" - super(_WindowDataset, self).__init__() - self._input_dataset = input_dataset - self._window_size = ops.convert_to_tensor( - window_size, dtype=dtypes.int64, name="window_size") - self._output_classes = nest.pack_sequence_as( - input_dataset.output_classes, - [ - dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access - output_classes=output_class, - output_shapes=output_shape, - output_types=output_type) - for output_class, output_shape, output_type in zip( - nest.flatten(input_dataset.output_classes), - nest.flatten(input_dataset.output_shapes), - nest.flatten(input_dataset.output_types)) - ]) - self._output_shapes = self._output_classes - self._output_types = self._output_classes - - def _as_variant_tensor(self): - return gen_dataset_ops.window_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._window_size, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types + super(Reducer, self).__init__(init_func, reduce_func, finalize_func) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 92d4251a864dae7d5725b0f177b54c5cbcc14aec..f50da4d429f715418a95cf177a3f4b5d273c8844 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,21 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import stateless -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.contrib.data.python.ops import random_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.parallel_interleave(...)`.") def parallel_interleave(map_func, cycle_length, block_length=1, @@ -81,12 +72,9 @@ def parallel_interleave(map_func, A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - return readers.ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy, - buffer_output_elements, prefetch_input_elements) - - return _apply_fn + return interleave_ops.parallel_interleave( + map_func, cycle_length, block_length, sloppy, buffer_output_elements, + prefetch_input_elements) @deprecation.deprecated( @@ -140,58 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - return readers.ParallelInterleaveDataset( - dataset, - map_func, - cycle_length, - block_length, - sloppy=True, - buffer_output_elements=None, - prefetch_input_elements=None) - - return _apply_fn - - -class _DirectedInterleaveDataset(dataset_ops.Dataset): - """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" - - def __init__(self, selector_input, data_inputs): - self._selector_input = selector_input - self._data_inputs = list(data_inputs) - - for data_input in data_inputs[1:]: - if (data_input.output_types != data_inputs[0].output_types or - data_input.output_classes != data_inputs[0].output_classes): - raise TypeError("All datasets must have the same type and class.") - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.directed_interleave_dataset( - self._selector_input._as_variant_tensor(), - [data_input._as_variant_tensor() for data_input in self._data_inputs], - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_classes(self): - return self._data_inputs[0].output_classes - - @property - def output_shapes(self): - ret = self._data_inputs[0].output_shapes - for data_input in self._data_inputs[1:]: - ret = nest.pack_sequence_as(ret, [ - ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( - nest.flatten(ret), nest.flatten(data_input.output_shapes)) - ]) - return ret - - @property - def output_types(self): - return self._data_inputs[0].output_types + return interleave_ops.parallel_interleave( + map_func, cycle_length, block_length, sloppy=True) +@deprecation.deprecated(None, + "Use `tf.data.experimental.sample_from_datasets(...)`.") def sample_from_datasets(datasets, weights=None, seed=None): """Samples elements at random from the datasets in `datasets`. @@ -215,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None): ValueError: If the `weights` argument is specified and does not match the length of the `datasets` element. """ - num_datasets = len(datasets) - if not isinstance(weights, dataset_ops.Dataset): - if weights is None: - # Select inputs with uniform probability. - logits = [[1.0] * num_datasets] - - else: - # Use the given `weights` as the probability of choosing the respective - # input. - weights = ops.convert_to_tensor(weights, name="weights") - if weights.dtype not in (dtypes.float32, dtypes.float64): - raise TypeError("`weights` must be convertible to a tensor of " - "`tf.float32` or `tf.float64` elements.") - if not weights.shape.is_compatible_with([num_datasets]): - raise ValueError( - "`weights` must be a vector of length `len(datasets)`.") - - # The `stateless_multinomial()` op expects log-probabilities, as opposed - # to weights. - logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) - - # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it - # is a `Dataset`, it is possible that evaluating it has a side effect the - # user depends on. - if len(datasets) == 1: - return datasets[0] - - def select_dataset_constant_logits(seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - - selector_input = dataset_ops.MapDataset( - random_ops.RandomDataset(seed).batch(2), - select_dataset_constant_logits, - use_inter_op_parallelism=False) - - else: - # Use each element of the given `weights` dataset as the probability of - # choosing the respective input. - - # The `stateless_multinomial()` op expects log-probabilities, as opposed to - # weights. - logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) - - def select_dataset_varying_logits(logits, seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - - logits_and_seeds = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2))) - selector_input = dataset_ops.MapDataset( - logits_and_seeds, - select_dataset_varying_logits, - use_inter_op_parallelism=False) - - return _DirectedInterleaveDataset(selector_input, datasets) + return interleave_ops.sample_from_datasets(datasets, weights, seed) +@deprecation.deprecated(None, + "Use `tf.data.experimental.choose_from_datasets(...)`.") def choose_from_datasets(datasets, choice_dataset): """Creates a dataset that deterministically chooses elements from `datasets`. @@ -308,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset): TypeError: If the `datasets` or `choice_dataset` arguments have the wrong type. """ - if not (choice_dataset.output_types == dtypes.int64 - and choice_dataset.output_shapes.is_compatible_with( - tensor_shape.scalar()) - and choice_dataset.output_classes == ops.Tensor): - raise TypeError("`choice_dataset` must be a dataset of scalar " - "`tf.int64` tensors.") - return _DirectedInterleaveDataset(choice_dataset, datasets) + return interleave_ops.choose_from_datasets(datasets, choice_dataset) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 18515e21edfe0449514ab4f21683a600eaf48910..48c325c86f74b4c922e70a33212b49196b34e357 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,15 +16,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import checkpoint_management -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training import session_run_hook +from tensorflow.python.data.experimental.ops import iterator_ops +from tensorflow.python.util import deprecation + +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.") def make_saveable_from_iterator(iterator): """Returns a SaveableObject for saving/restore iterator state using Saver. @@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator): Note: Not all iterators support checkpointing yet. Attempting to save the state of an unsupported iterator will throw an error. """ - return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access - - -class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" + return iterator_ops.make_saveable_from_iterator(iterator) - def __init__(self, iterator_resource): - serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) - specs = [ - saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") - ] - super(_Saveable, self).__init__(iterator_resource, specs, - iterator_resource.name) - def restore(self, restored_tensors, unused_restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) - - -class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): +class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook): """Checkpoints input pipeline state every N steps or seconds. This hook saves the state of the iterators in the `Graph` so that when @@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): collector when building the eval graph. """ + @deprecation.deprecated( + None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.") def __init__(self, estimator): - """Initializes a `CheckpointInputPipelineHook`. - - Args: - estimator: Estimator. - - Raises: - ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of saver or scaffold should be set. - """ - # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or - # of the form "input__.ckpt" for distributed pipelines. - # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is - # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix - # to be different to avoid conflicts with the model checkpoint. - - # pylint: disable=protected-access - checkpoint_prefix = "input" - if estimator._config.num_worker_replicas > 1: - # Distributed setting. - suffix = "_{}_{}".format(estimator._config.task_type, - estimator._config.task_id) - checkpoint_prefix += suffix - # pylint: enable=protected-access - - # We use a composition paradigm instead of inheriting from - # `CheckpointSaverHook` because `Estimator` does an `isinstance` check - # to check whether a `CheckpointSaverHook` is already present in the list - # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` - # would thwart this behavior. This hook checkpoints *only the iterators* - # and not the graph variables. - self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( - estimator.model_dir, - save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access - save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access - checkpoint_basename=checkpoint_prefix + ".ckpt") - - # Name for the protocol buffer file that will contain the list of most - # recent checkpoints stored as a `CheckpointState` protocol buffer. - # This file, kept in the same directory as the checkpoint files, is - # automatically managed by the `Saver` to keep track of recent checkpoints. - # The default name used by the `Saver` for this file is "checkpoint". Here - # we use the name "checkpoint_" so that in case the - # `checkpoint_dir` is the same as the model checkpoint directory, there are - # no conflicts during restore. - self._latest_filename = "checkpoint_" + checkpoint_prefix - self._first_run = True - - def begin(self): - # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` - # collection if no `Saver` or `Scaffold` is provided. - # pylint: disable=protected-access - if (self._checkpoint_saver_hook._saver is None and - self._checkpoint_saver_hook._scaffold is None): - iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) - saveables = [_Saveable(i) for i in iterators] - self._checkpoint_saver_hook._saver = _CustomSaver(saveables, - self._latest_filename) - # pylint: enable=protected-access - self._checkpoint_saver_hook.begin() - - def _restore_or_save_initial_ckpt(self, session): - # Ideally this should be run in after_create_session but is not for the - # following reason: - # Currently there is no way of enforcing an order of running the - # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` - # is run *after* this hook. That is troublesome because - # 1. If a checkpoint exists and this hook restores it, the initializer hook - # will override it. - # 2. If no checkpoint exists, this hook will try to save an initialized - # iterator which will result in an exception. - # - # As a temporary fix we enter the following implicit contract between this - # hook and the _DatasetInitializerHook. - # 1. The _DatasetInitializerHook initializes the iterator in the call to - # after_create_session. - # 2. This hook saves the iterator on the first call to `before_run()`, which - # is guaranteed to happen after `after_create_session()` of all hooks - # have been run. - - # Check if there is an existing checkpoint. If so, restore from it. - # pylint: disable=protected-access - latest_checkpoint_path = checkpoint_management.latest_checkpoint( - self._checkpoint_saver_hook._checkpoint_dir, - latest_filename=self._latest_filename) - if latest_checkpoint_path: - self._checkpoint_saver_hook._get_saver().restore(session, - latest_checkpoint_path) - else: - # The checkpoint saved here is the state at step "global_step". - # Note: We do not save the GraphDef or MetaGraphDef here. - global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) - self._checkpoint_saver_hook._save(session, global_step) - self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) - # pylint: enable=protected-access - - def before_run(self, run_context): - if self._first_run: - self._restore_or_save_initial_ckpt(run_context.session) - self._first_run = False - return self._checkpoint_saver_hook.before_run(run_context) - - def after_run(self, run_context, run_values): - self._checkpoint_saver_hook.after_run(run_context, run_values) - - def end(self, session): - self._checkpoint_saver_hook.end(session) - - -class _CustomSaver(saver_lib.Saver): - """`Saver` with a different default `latest_filename`. - - This is used in the `CheckpointInputPipelineHook` to avoid conflicts with - the model ckpt saved by the `CheckpointSaverHook`. - """ - - def __init__(self, var_list, latest_filename): - super(_CustomSaver, self).__init__(var_list) - self._latest_filename = latest_filename - - def save(self, - sess, - save_path, - global_step=None, - latest_filename=None, - meta_graph_suffix="meta", - write_meta_graph=True, - write_state=True, - strip_default_attrs=False): - return super(_CustomSaver, self).save( - sess, save_path, global_step, latest_filename or self._latest_filename, - meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) + super(CheckpointInputPipelineHook, self).__init__(estimator) diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py index 2701605e641b190852bb9934ce83f7fc3e90ff15..3aeee9d8e42dce5af133afeeab4a8c97e50d5571 100644 --- a/tensorflow/contrib/data/python/ops/parsing_ops.py +++ b/tensorflow/contrib/data/python/ops/parsing_ops.py @@ -17,92 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import parsing_ops +from tensorflow.python.data.experimental.ops import parsing_ops +from tensorflow.python.util import deprecation -class _ParseExampleDataset(dataset_ops.Dataset): - """A `Dataset` that parses `example` dataset into a `dict` dataset.""" - - def __init__(self, input_dataset, features, num_parallel_calls): - super(_ParseExampleDataset, self).__init__() - self._input_dataset = input_dataset - if not all(types == dtypes.string - for types in nest.flatten(input_dataset.output_types)): - raise TypeError("Input dataset should be a dataset of vectors of strings") - self._num_parallel_calls = num_parallel_calls - # pylint: disable=protected-access - self._features = parsing_ops._prepend_none_dimension(features) - # sparse_keys and dense_keys come back sorted here. - (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, - dense_shapes) = parsing_ops._features_to_raw_params( - self._features, [ - parsing_ops.VarLenFeature, parsing_ops.SparseFeature, - parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature - ]) - # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature. - (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes, - dense_shape_as_shape) = parsing_ops._process_raw_parameters( - None, dense_defaults, sparse_keys, sparse_types, dense_keys, - dense_types, dense_shapes) - # pylint: enable=protected-access - self._sparse_keys = sparse_keys - self._sparse_types = sparse_types - self._dense_keys = dense_keys - self._dense_defaults = dense_defaults_vec - self._dense_shapes = dense_shapes - self._dense_types = dense_types - dense_output_shapes = [ - self._input_dataset.output_shapes.concatenate(shape) - for shape in dense_shape_as_shape - ] - sparse_output_shapes = [ - self._input_dataset.output_shapes.concatenate([None]) - for _ in range(len(sparse_keys)) - ] - - self._output_shapes = dict( - zip(self._dense_keys + self._sparse_keys, - dense_output_shapes + sparse_output_shapes)) - self._output_types = dict( - zip(self._dense_keys + self._sparse_keys, - self._dense_types + self._sparse_types)) - self._output_classes = dict( - zip(self._dense_keys + self._sparse_keys, - [ops.Tensor for _ in range(len(self._dense_defaults))] + - [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys)) - ])) - - def _as_variant_tensor(self): - return gen_dataset_ops.parse_example_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._num_parallel_calls, - self._dense_defaults, - self._sparse_keys, - self._dense_keys, - self._sparse_types, - self._dense_shapes, - **dataset_ops.flat_structure(self)) - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - @property - def output_classes(self): - return self._output_classes - - -# TODO(b/111553342): add arguments names and example names as well. +@deprecation.deprecated( + None, "Use `tf.data.experimental.parse_example_dataset(...)`.") def parse_example_dataset(features, num_parallel_calls=1): """A transformation that parses `Example` protos into a `dict` of tensors. @@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1): Raises: ValueError: if features argument is None. """ - if features is None: - raise ValueError("Missing: features was %s." % features) - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls) - if any([ - isinstance(feature, parsing_ops.SparseFeature) - for _, feature in features.items() - ]): - # pylint: disable=protected-access - # pylint: disable=g-long-lambda - out_dataset = out_dataset.map( - lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features( - features, x), num_parallel_calls=num_parallel_calls) - return out_dataset - - return _apply_fn + return parsing_ops.parse_example_dataset(features, num_parallel_calls) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 5222011d045efd9a64b4e89b248303cffbcb0b37..adfb390cd9a6b159fe3887666993c6e9d6c758d8 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,320 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import warnings - -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.eager import context -from tensorflow.python.framework import device as framework_device -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops -from tensorflow.python.ops import resource_variable_ops - - -def function_buffering_resource(string_arg, - target_device, - f, - buffer_size, - output_types, - container="", - shared_name=None, - name=None): - """Creates a FunctionBufferingResource. - - A FunctionBufferingResource fills up a buffer by calling a function `f` on - `target_device`. `f` should take in only a single string argument as input. - - Args: - string_arg: The single string argument to the function. - target_device: The device to run `f` on. - f: The function to be executed. - buffer_size: Size of the buffer to be populated. - output_types: The output types generated by the function. - container: (Optional) string. Defaults to "". - shared_name: (Optional) string. - name: (Optional) string to name the op. - - Returns: - Handle to a FunctionBufferingResource. - """ - if shared_name is None: - shared_name = "" - return gen_dataset_ops.function_buffering_resource( - string_arg=string_arg, - target_device=target_device, - shared_name=shared_name, - f=f, - buffer_size=buffer_size, - container=container, - name=name, - output_types=output_types) - - -def function_buffering_resource_get_next(function_buffer_resource, - output_types, - name=None): - return gen_dataset_ops.function_buffering_resource_get_next( - function_buffer_resource=function_buffer_resource, - output_types=output_types, - name=name) - - -def function_buffering_resource_reset(function_buffer_resource, name=None): - return gen_dataset_ops.function_buffering_resource_reset( - function_buffer_resource=function_buffer_resource, name=name) - - -# pylint: disable=protected-access -class _PrefetchToDeviceIterator(object): - """A replacement for `tf.data.Iterator` that prefetches to another device. - - Args: - input_dataset: The input dataset - one_shot: If true, we make a one shot iterator that's already initialized. - device: A fully specified device string where we want to prefetch to - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be - shared under the given name across multiple sessions that share the - same devices (e.g. when using a remote server). - - Returns: - An Iterator type object. - """ - - def __init__(self, - input_dataset, - one_shot, - device, - buffer_size, - shared_name=None): - self._input_dataset = input_dataset - self._get_next_call_count = 0 - self._one_shot = one_shot - if shared_name is None: - shared_name = "" - - if self._one_shot: - self._input_iterator = input_dataset.make_one_shot_iterator() - else: - self._input_iterator = iterator_ops.Iterator.from_structure( - self._input_dataset.output_types, self._input_dataset.output_shapes, - shared_name, self._input_dataset.output_classes) - input_iterator_handle = self._input_iterator.string_handle() - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self._input_iterator.output_types, - self._input_iterator.output_shapes, - self._input_iterator.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - iterator_device = gen_dataset_ops.iterator_get_device( - self._input_iterator._iterator_resource) - - with ops.device(device): - self._buffering_resource = function_buffering_resource( - f=_prefetch_fn, - target_device=iterator_device, - string_arg=input_iterator_handle, - buffer_size=buffer_size, - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes))) - - if not self._one_shot: - reset_op = function_buffering_resource_reset(self._buffering_resource) - with ops.control_dependencies([reset_op]): - self._initializer = self._input_iterator.make_initializer( - self._input_dataset) - - def get_next(self, name=None): - """See `tf.data.Iterator.get_next`.""" - self._get_next_call_count += 1 - if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: - warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) - - flat_ret = gen_dataset_ops.function_buffering_resource_get_next( - self._buffering_resource, - output_types=nest.flatten(sparse.as_dense_types( - self.output_types, self.output_classes)), name=name) - - ret = sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self.output_types, flat_ret), - self.output_types, self.output_shapes, self.output_classes) - - for tensor, shape in zip( - nest.flatten(ret), nest.flatten(self.output_shapes)): - if isinstance(tensor, ops.Tensor): - tensor.set_shape(shape) - - return ret - - @property - def initializer(self): - if self._one_shot: - raise NotImplementedError("Can't initialize a one_shot_iterator") - return self._initializer - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): - """A replacement for `tf.data.Iterator` that prefetches to another device. - - Args: - input_dataset: The input dataset - one_shot: If true, we make a one shot iterator that's already initialized. - device: A fully specified device string where we want to prefetch to - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be - shared under the given name across multiple sessions that share the - same devices (e.g. when using a remote server). - - Returns: - An Iterator type object. - """ - - def __init__(self, - input_dataset, - device, - buffer_size): - with ops.device("/device:CPU:0"): - super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) - input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle( - self._resource) - - self._device = device - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self.output_types, self.output_shapes, self.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - _prefetch_fn.add_to_graph(None) - - with ops.device(device): - self._buffering_resource = function_buffering_resource( - f=_prefetch_fn, - output_types=self._flat_output_types, - target_device=gen_dataset_ops.iterator_get_device(self._resource), - string_arg=input_iterator_handle, - buffer_size=buffer_size, - shared_name=iterator_ops._generate_shared_name( - "function_buffer_resource")) - - def _next_internal(self): - """Returns a nested structure of `tf.Tensor`s containing the next element. - """ - # This runs in sync mode as iterators use an error status to communicate - # that there is no more data to iterate over. - # TODO(b/77291417): Fix - with context.execution_mode(context.SYNC): - with ops.device(self._device): - ret = gen_dataset_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffering_resource, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) -# pylint: enable=protected-access - - -class _PrefetchToDeviceDataset(dataset_ops.Dataset): - """A `Dataset` whose iterator prefetches elements to another device.""" - - def __init__(self, input_dataset, device, buffer_size): - self._input_dataset = input_dataset - self._device = device - self._buffer_size = buffer_size if buffer_size is not None else 1 - - # The static analysis cannot tell that the eager iterator's superclass has - # a `next()` method. - # pylint: disable=non-iterator-returned - def __iter__(self): - """Creates an `Iterator` for enumerating the elements of this dataset. - - The returned iterator implements the Python iterator protocol and therefore - can only be used in eager mode. - - Returns: - An `Iterator` over the elements of this dataset. - - Raises: - RuntimeError: If eager execution is enabled. - """ - if context.executing_eagerly(): - return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, - self._buffer_size) - else: - raise RuntimeError("dataset.__iter__() is only supported when eager " - "execution is enabled.") - # pylint: enable=non-iterator-returned - - def make_one_shot_iterator(self): - if context.executing_eagerly(): - return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, - self._buffer_size) - else: - return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True, - device=self._device, - buffer_size=self._buffer_size) - - def make_initializable_iterator(self, shared_name=None): - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=False, - device=self._device, - buffer_size=self._buffer_size, - shared_name=shared_name) - - def _as_variant_tensor(self): - # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset - # transformation methods is called. - # TODO(mrry): Investigate support for chaining further transformations after - # the prefetch, including GPU support. - raise NotImplementedError("`prefetch_to_device()` must be the last " - "transformation in a dataset pipeline.") - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes +from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.prefetch_to_device(...)`.") def prefetch_to_device(device, buffer_size=None): """A transformation that prefetches dataset values to the given `device`. @@ -346,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - return _PrefetchToDeviceDataset(dataset, device, buffer_size) - - return _apply_fn + return prefetching_ops.prefetch_to_device(device, buffer_size) +@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.") def copy_to_device(target_device, source_device="/cpu:0"): """A transformation that copies dataset elements to the given `target_device`. @@ -363,348 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _CopyToDeviceDataset( - dataset, target_device=target_device, source_device=source_device) - - return _apply_fn - - -# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate -# all inputs to the Op are in host memory, thereby avoiding some unnecessary -# Sends and Recvs. -class _CopyToDeviceDataset(dataset_ops.Dataset): - """A `Dataset` that copies elements to another device.""" - - def __init__(self, input_dataset, target_device, source_device="/cpu:0"): - """Constructs a _CopyToDeviceDataset. - - Args: - input_dataset: `Dataset` to be copied - target_device: The name of the device to which elements would be copied. - source_device: Device where input_dataset would be placed. - """ - self._input_dataset = input_dataset - self._target_device = target_device - spec = framework_device.DeviceSpec().from_string(self._target_device) - self._is_gpu_target = (spec.device_type == "GPU") - self._source_device_string = source_device - self._source_device = ops.convert_to_tensor(source_device) - - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._input_dataset.output_shapes, - self._input_dataset.output_classes)) - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes)) - - @function.Defun() - def _init_func(): - """Creates an iterator for the input dataset. - - Returns: - A `string` tensor that encapsulates the iterator created. - """ - # pylint: disable=protected-access - ds_variant = self._input_dataset._as_variant_tensor() - resource = core_gen_dataset_ops.anonymous_iterator( - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - with ops.control_dependencies( - [core_gen_dataset_ops.make_iterator(ds_variant, resource)]): - return core_gen_dataset_ops.iterator_to_string_handle(resource) - - @function.Defun() - def _remote_init_func(): - return functional_ops.remote_call( - target=self._source_device, - args=_init_func.captured_inputs, - Tout=[dtypes.string], - f=_init_func) - - self._init_func = _remote_init_func - self._init_captured_args = _remote_init_func.captured_inputs - - @function.Defun(dtypes.string) - def _next_func(string_handle): - """Calls get_next for created iterator. - - Args: - string_handle: An iterator string handle created by _init_func - Returns: - The elements generated from `input_dataset` - """ - with ops.device(self._source_device_string): - iterator = iterator_ops.Iterator.from_string_handle( - string_handle, self.output_types, self.output_shapes, - self.output_classes) - ret = iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - @function.Defun(dtypes.string) - def _remote_next_func(string_handle): - return functional_ops.remote_call( - target=self._source_device, - args=[string_handle] + _next_func.captured_inputs, - Tout=self._flat_output_types, - f=_next_func) - - self._next_func = _remote_next_func - self._next_captured_args = _remote_next_func.captured_inputs - - @function.Defun(dtypes.string) - def _finalize_func(string_handle): - """Destroys the iterator resource created. - - Args: - string_handle: An iterator string handle created by _init_func - Returns: - Tensor constant 0 - """ - iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2( - string_handle, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - with ops.control_dependencies([ - resource_variable_ops.destroy_resource_op( - iterator_resource, ignore_lookup_error=True)]): - return array_ops.constant(0, dtypes.int64) - - @function.Defun(dtypes.string) - def _remote_finalize_func(string_handle): - return functional_ops.remote_call( - target=self._source_device, - args=[string_handle] + _finalize_func.captured_inputs, - Tout=[dtypes.int64], - f=_finalize_func) - - self._finalize_func = _remote_finalize_func - self._finalize_captured_args = _remote_finalize_func.captured_inputs - - g = ops.get_default_graph() - _remote_init_func.add_to_graph(g) - _remote_next_func.add_to_graph(g) - _remote_finalize_func.add_to_graph(g) - # pylint: enable=protected-scope - - # The one_shot_iterator implementation needs a 0 arg _make_dataset function - # that thereby captures all the inputs required to create the dataset. Since - # there are strings that are inputs to the GeneratorDataset which can't be - # placed on a GPU, this fails for the GPU case. Therefore, disabling it for - # GPU - def make_one_shot_iterator(self): - if self._is_gpu_target: - raise ValueError("Cannot create a one shot iterator when using " - "`tf.contrib.data.copy_to_device()` on GPU. Please use " - "`Dataset.make_initializable_iterator()` instead.") - else: - return super(_CopyToDeviceDataset, self).make_one_shot_iterator() - - def _as_variant_tensor(self): - with ops.device(self._target_device): - return core_gen_dataset_ops.generator_dataset( - self._init_captured_args, - self._next_captured_args, - self._finalize_captured_args, - init_func=self._init_func, - next_func=self._next_func, - finalize_func=self._finalize_func, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes - - -class _PerDeviceGenerator(dataset_ops.Dataset): - """A `dummy` generator dataset.""" - - def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, - source_device, target_device, output_shapes, output_types, - output_classes): - self._target_device = target_device - self._output_types = output_types - self._output_shapes = output_shapes - self._output_classes = output_classes - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._output_types, self._output_classes)) - - multi_device_iterator_string_handle = ( - gen_dataset_ops.multi_device_iterator_to_string_handle( - multi_device_iterator_resource)) - - @function.Defun() - def _init_func(): - return multi_device_iterator_string_handle - - @function.Defun() - def _remote_init_func(): - return functional_ops.remote_call( - target=source_device, - args=_init_func.captured_inputs, - Tout=[dtypes.string], - f=_init_func) - - self._init_func = _remote_init_func - self._init_captured_args = _remote_init_func.captured_inputs - - @function.Defun(dtypes.string) - def _next_func(string_handle): - multi_device_iterator = ( - gen_dataset_ops.multi_device_iterator_from_string_handle( - string_handle=string_handle, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes)) - return gen_dataset_ops.multi_device_iterator_get_next_from_shard( - multi_device_iterator=multi_device_iterator, - shard_num=shard_num, - incarnation_id=incarnation_id, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - @function.Defun(dtypes.string) - def _remote_next_func(string_handle): - return functional_ops.remote_call( - target=source_device, - args=[string_handle] + _next_func.captured_inputs, - Tout=self._flat_output_types, - f=_next_func) - - self._next_func = _remote_next_func - self._next_captured_args = _remote_next_func.captured_inputs - - @function.Defun(dtypes.string) - def _finalize_func(unused_string_handle): - return array_ops.constant(0, dtypes.int64) - - @function.Defun(dtypes.string) - def _remote_finalize_func(string_handle): - return functional_ops.remote_call( - target=source_device, - args=[string_handle] + _finalize_func.captured_inputs, - Tout=[dtypes.int64], - f=_finalize_func) - - self._finalize_func = _remote_finalize_func - self._finalize_captured_args = _remote_finalize_func.captured_inputs - - def _as_variant_tensor(self): - with ops.device(self._target_device): - return core_gen_dataset_ops.generator_dataset( - self._init_captured_args, - self._next_captured_args, - self._finalize_captured_args, - init_func=self._init_func, - next_func=self._next_func, - finalize_func=self._finalize_func, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_classes(self): - return self._output_classes - - -class MultiDeviceIterator(object): - """An iterator over multiple devices.""" - - def __init__(self, - dataset, - devices, - max_buffer_size=1, - prefetch_buffer_size=1, - source_device="/cpu:0"): - """Constructs a MultiDeviceIterator. - - Args: - dataset: The input dataset to be iterated over. - devices: The list of devices to fetch data to. - max_buffer_size: Maximum size of the host side per device buffer to keep. - prefetch_buffer_size: if > 1, then we setup a buffer on each device - to prefetch into. - source_device: The host device to place the `dataset` on. - """ - self._dataset = dataset - self._devices = devices - self._source_device = source_device - self._source_device_tensor = ops.convert_to_tensor(source_device) - - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._dataset.output_shapes, - self._dataset.output_classes)) - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._dataset.output_types, - self._dataset.output_classes)) - - # Create the MultiDeviceIterator. - with ops.device(self._source_device): - self._multi_device_iterator_resource = ( - gen_dataset_ops.multi_device_iterator( - devices=self._devices, - shared_name="", - container="", - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes)) - - # The incarnation ID is used to ensure consistency between the per-device - # iterators and the multi-device iterator. - self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( - self._dataset._as_variant_tensor(), # pylint: disable=protected-access - self._multi_device_iterator_resource, - max_buffer_size=max_buffer_size) - - # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to - # initialize the device side of the pipeline. This would allow the - # MultiDeviceIterator to choose, for example, to move some transformations - # into the device side from its input. It might be useful in rewriting. - # Create the per device iterators. - self._device_iterators = [] - i = 0 - for device in self._devices: - ds = _PerDeviceGenerator( - i, self._multi_device_iterator_resource, self._incarnation_id, - self._source_device_tensor, device, self._dataset.output_shapes, - self._dataset.output_types, self._dataset.output_classes) - if prefetch_buffer_size > 0: - ds = ds.prefetch(prefetch_buffer_size) - with ops.device(device): - self._device_iterators.append(ds.make_initializable_iterator()) - i += 1 - - device_iterator_initializers = [ - iterator.initializer for iterator in self._device_iterators - ] - self._initializer = control_flow_ops.group(*device_iterator_initializers) - - def get_next(self): - result = [] - i = 0 - for device in self._devices: - with ops.device(device): - result.append(self._device_iterators[i].get_next()) - i += 1 - return result - - @property - def initializer(self): - return self._initializer + return prefetching_ops.copy_to_device(target_device, source_device) diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index e670c4c8354f4067eb21c9b1fce708147c162967..2c951256368a5ffdbc2be424cef12eafc6ecd782 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -17,36 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import random_seed -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.data.experimental.ops import random_ops +from tensorflow.python.util import deprecation -class RandomDataset(dataset_ops.Dataset): +class RandomDataset(random_ops.RandomDataset): """A `Dataset` of pseudorandom values.""" + @deprecation.deprecated( + None, "Use `tf.data.experimental.RandomDataset(...)`.") def __init__(self, seed=None): - """A `Dataset` of pseudorandom values.""" - super(RandomDataset, self).__init__() - self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - return gen_dataset_ops.random_dataset( - seed=self._seed, - seed2=self._seed2, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.int64 + super(RandomDataset, self).__init__(seed) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4c466781f7f659e8d7e267500a118d482d76da15..4601376dff47e161962e92678883039c4b88bab7 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,297 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import csv - -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops -from tensorflow.contrib.data.python.ops import interleave_ops -from tensorflow.contrib.data.python.ops import parsing_ops -from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.experimental.ops import optimization +from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.platform import gfile +from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation -_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, - dtypes.int64, dtypes.string) - - -def _is_valid_int32(str_val): - try: - # Checks equality to prevent int32 overflow - return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( - str_val) - except (ValueError, OverflowError): - return False - - -def _is_valid_int64(str_val): - try: - dtypes.int64.as_numpy_dtype(str_val) - return True - except (ValueError, OverflowError): - return False - - -def _is_valid_float(str_val, float_dtype): - try: - return float_dtype.as_numpy_dtype(str_val) < np.inf - except ValueError: - return False - - -def _infer_type(str_val, na_value, prev_type): - """Given a string, infers its tensor type. - - Infers the type of a value by picking the least 'permissive' type possible, - while still allowing the previous type inference for this column to be valid. - - Args: - str_val: String value to infer the type of. - na_value: Additional string to recognize as a NA/NaN CSV value. - prev_type: Type previously inferred based on values of this column that - we've seen up till now. - Returns: - Inferred dtype. - """ - if str_val in ("", na_value): - # If the field is null, it gives no extra information about its type - return prev_type - - type_list = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string - ] # list of types to try, ordered from least permissive to most - - type_functions = [ - _is_valid_int32, - _is_valid_int64, - lambda str_val: _is_valid_float(str_val, dtypes.float32), - lambda str_val: _is_valid_float(str_val, dtypes.float64), - lambda str_val: True, - ] # Corresponding list of validation functions - - for i in range(len(type_list)): - validation_fn = type_functions[i] - if validation_fn(str_val) and (prev_type is None or - prev_type in type_list[:i + 1]): - return type_list[i] - - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): - """Generator that yields rows of CSV file(s) in order.""" - for fn in filenames: - with file_io.FileIO(fn, "r") as f: - rdr = csv.reader( - f, - delimiter=field_delim, - quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) - if header: - next(rdr) # Skip header lines - - for csv_row in rdr: - if len(csv_row) != num_cols: - raise ValueError( - "Problem inferring types: CSV row has different number of fields " - "than expected.") - yield csv_row - - -def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, num_rows_for_inference, - select_columns): - """Infers column types from the first N valid CSV records of files.""" - if select_columns is None: - select_columns = range(num_cols) - inferred_types = [None] * len(select_columns) - - for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): - if num_rows_for_inference is not None and i >= num_rows_for_inference: - break - - for j, col_index in enumerate(select_columns): - inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j]) - - # Replace None's with a default type - inferred_types = [t or dtypes.string for t in inferred_types] - # Default to 0 or '' for null values - return [ - constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) - for t in inferred_types - ] - - -def _infer_column_names(filenames, field_delim, use_quote_delim): - """Infers column names from first rows of files.""" - csv_kwargs = { - "delimiter": field_delim, - "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE - } - with file_io.FileIO(filenames[0], "r") as f: - try: - column_names = next(csv.reader(f, **csv_kwargs)) - except StopIteration: - raise ValueError(("Received StopIteration when reading the header line " - "of %s. Empty file?") % filenames[0]) - - for name in filenames[1:]: - with file_io.FileIO(name, "r") as f: - try: - if next(csv.reader(f, **csv_kwargs)) != column_names: - raise ValueError( - "Files have different column names in the header row.") - except StopIteration: - raise ValueError(("Received StopIteration when reading the header line " - "of %s. Empty file?") % filenames[0]) - return column_names - - -def _get_sorted_col_indices(select_columns, column_names): - """Transforms select_columns argument into sorted column indices.""" - names_to_indices = {n: i for i, n in enumerate(column_names)} - num_cols = len(column_names) - for i, v in enumerate(select_columns): - if isinstance(v, int): - if v < 0 or v >= num_cols: - raise ValueError( - "Column index %d specified in select_columns out of valid range." % - v) - continue - if v not in names_to_indices: - raise ValueError( - "Value '%s' specified in select_columns not a valid column index or " - "name." % v) - select_columns[i] = names_to_indices[v] - - # Sort and ensure there are no duplicates - result = sorted(set(select_columns)) - if len(result) != len(select_columns): - raise ValueError("select_columns contains duplicate columns") - return result - - -def _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): - """Optionally shuffle and repeat dataset, as requested.""" - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - return dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - return dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - return dataset.repeat(num_epochs) - return dataset - - -def make_tf_record_dataset( - file_pattern, - batch_size, - parser_fn=None, - num_epochs=None, - shuffle=True, - shuffle_buffer_size=None, - shuffle_seed=None, - prefetch_buffer_size=None, - num_parallel_reads=None, - num_parallel_parser_calls=None, - drop_final_batch=False): - """Reads and optionally parses TFRecord files into a dataset. - - Provides common functionality such as batching, optional parsing, shuffling, - and performant defaults. - - Args: - file_pattern: List of files or patterns of TFRecord file paths. - See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of records to combine - in a single batch. - parser_fn: (Optional.) A function accepting string input to parse - and process the record contents. This function must map records - to components of a fixed shape, so they may be batched. By - default, uses the record contents unmodified. - num_epochs: (Optional.) An int specifying the number of times this - dataset is repeated. If None (the default), cycles through the - dataset forever. - shuffle: (Optional.) A bool that indicates whether the input - should be shuffled. Defaults to `True`. - shuffle_buffer_size: (Optional.) Buffer size to use for - shuffling. A large buffer size ensures better shuffling, but - increases memory usage and startup time. - shuffle_seed: (Optional.) Randomization seed to use for shuffling. - prefetch_buffer_size: (Optional.) An int specifying the number of - feature batches to prefetch for performance improvement. - Defaults to auto-tune. Set to 0 to disable prefetching. - num_parallel_reads: (Optional.) Number of threads used to read - records from files. By default or if set to a value >1, the - results will be interleaved. - num_parallel_parser_calls: (Optional.) Number of parallel - records to parse in parallel. Defaults to an automatic selection. - drop_final_batch: (Optional.) Whether the last batch should be - dropped in case its size is smaller than `batch_size`; the - default behavior is not to drop the smaller batch. - - Returns: - A dataset, where each element matches the output of `parser_fn` - except it will have an additional leading `batch-size` dimension, - or a `batch_size`-length 1-D tensor of strings if `parser_fn` is - unspecified. - """ - files = dataset_ops.Dataset.list_files( - file_pattern, shuffle=shuffle, seed=shuffle_seed) - - if num_parallel_reads is None: - # Note: We considered auto-tuning this value, but there is a concern - # that this affects the mixing of records from different files, which - # could affect training convergence/accuracy, so we are defaulting to - # a constant for now. - num_parallel_reads = 24 - dataset = core_readers.TFRecordDataset( - files, num_parallel_reads=num_parallel_reads) - - if shuffle_buffer_size is None: - # TODO(josh11b): Auto-tune this value when not specified - shuffle_buffer_size = 10000 - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. - # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. - drop_final_batch = drop_final_batch or num_epochs is None - - if parser_fn is None: - dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) - else: - # TODO(josh11b): if num_parallel_parser_calls is None, use some function - # of num cores instead of map_and_batch's default behavior of one batch. - dataset = dataset.apply(batching.map_and_batch( - parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, - drop_remainder=drop_final_batch)) - - if prefetch_buffer_size is None: - prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE - if prefetch_buffer_size == 0: - return dataset - else: - return dataset.prefetch(buffer_size=prefetch_buffer_size) - +@deprecation.deprecated(None, + "Use `tf.data.experimental.make_csv_dataset(...)`.") def make_csv_dataset( file_pattern, batch_size, @@ -323,7 +46,7 @@ def make_csv_dataset( shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, num_parallel_reads=1, sloppy=False, num_rows_for_inference=100, @@ -386,9 +109,9 @@ def make_csv_dataset( shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. - prefetch_buffer_size: An int specifying the number of feature batches to - prefetch for performance improvement. Recommended value is the number of - batches consumed per training step. + prefetch_buffer_size: An int specifying the number of feature + batches to prefetch for performance improvement. Recommended value is the + number of batches consumed per training step. Defaults to auto-tune. num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. sloppy: If `True`, reading performance will be improved at @@ -412,106 +135,18 @@ def make_csv_dataset( Raises: ValueError: If any of the arguments is malformed. """ - # Create dataset of all matching filenames - filenames = _get_file_names(file_pattern, False) - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) - if shuffle: - dataset = dataset.shuffle(len(filenames), shuffle_seed) - - # Clean arguments; figure out column names and defaults - - if column_names is None: - if not header: - raise ValueError("Cannot infer column names without a header line.") - # If column names are not provided, infer from the header lines - column_names = _infer_column_names(filenames, field_delim, use_quote_delim) - if len(column_names) != len(set(column_names)): - raise ValueError("Cannot have duplicate column names.") - - if select_columns is not None: - select_columns = _get_sorted_col_indices(select_columns, column_names) - - if column_defaults is not None: - column_defaults = [ - constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x - for x in column_defaults - ] - else: - # If column defaults are not provided, infer from records at graph - # construction time - column_defaults = _infer_column_defaults( - filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, num_rows_for_inference, select_columns) - - if select_columns is not None and len(column_defaults) != len(select_columns): - raise ValueError( - "If specified, column_defaults and select_columns must have same " - "length." - ) - if select_columns is not None and len(column_names) > len(select_columns): - # Pick the relevant subset of column names - column_names = [column_names[i] for i in select_columns] - - if label_name is not None and label_name not in column_names: - raise ValueError("`label_name` provided must be one of the columns.") - - def filename_to_dataset(filename): - return CsvDataset( - filename, - record_defaults=column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - header=header, - compression_type=compression_type, - ) - - def map_fn(*columns): - """Organizes columns into a features dictionary. - - Args: - *columns: list of `Tensor`s corresponding to one csv record. - Returns: - An OrderedDict of feature names to values for that particular record. If - label_name is provided, extracts the label feature to be returned as the - second element of the tuple. - """ - features = collections.OrderedDict(zip(column_names, columns)) - if label_name is not None: - label = features.pop(label_name) - return features, label - return features - - # Read files sequentially (if num_parallel_reads=1) or in parallel - dataset = dataset.apply( - interleave_ops.parallel_interleave( - filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map. - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. - # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. - dataset = dataset.batch(batch_size=batch_size, - drop_remainder=num_epochs is None) - dataset = dataset_ops.MapDataset( - dataset, map_fn, use_inter_op_parallelism=False) - dataset = dataset.prefetch(prefetch_buffer_size) - - return dataset - - -_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + return readers.make_csv_dataset( + file_pattern, batch_size, column_names, column_defaults, label_name, + select_columns, field_delim, use_quote_delim, na_value, header, + num_epochs, shuffle, shuffle_buffer_size, shuffle_seed, + prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference, + compression_type) -class CsvDataset(dataset_ops.Dataset): +class CsvDataset(readers.CsvDataset): """A Dataset comprising lines from one or more CSV files.""" + @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.") def __init__(self, filenames, record_defaults, @@ -522,140 +157,13 @@ class CsvDataset(dataset_ops.Dataset): use_quote_delim=True, na_value="", select_cols=None): - """Creates a `CsvDataset` by reading and decoding CSV files. - - The elements of this dataset correspond to records from the file(s). - RFC 4180 format is expected for CSV files - (https://tools.ietf.org/html/rfc4180) - Note that we allow leading and trailing spaces with int or float field. - - - For example, suppose we have a file 'my_file0.csv' with four CSV columns of - different data types: - ``` - abcdefg,4.28E10,5.55E6,12 - hijklmn,-5.3E14,,2 - ``` - - We can construct a CsvDataset from it as follows: - ```python - dataset = tf.contrib.data.CsvDataset( - "my_file*.csv", - [tf.float32, # Required field, use dtype or empty tensor - tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 - tf.int32, # Required field, use dtype or empty tensor - ], - select_cols=[1,2,3] # Only parse last three columns - ) - ``` - - The expected output of its iterations is: - ```python - next_element = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - - >> (4.28e10, 5.55e6, 12) - >> (-5.3e14, 0.0, 2) - ``` - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - record_defaults: A list of default values for the CSV fields. Each item in - the list is either a valid CSV `DType` (float32, float64, int32, int64, - string), or a `Tensor` object with one of the above types. One per - column of CSV data, with either a scalar `Tensor` default value for the - column if it is optional, or `DType` or empty `Tensor` if required. If - both this and `select_columns` are specified, these must have the same - lengths, and `column_defaults` is assumed to be sorted in order of - increasing column index. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no - compression. - buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes - to buffer while reading files. Defaults to 4MB. - header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) - have header line(s) that should be skipped when parsing. Defaults to - `False`. - field_delim: (Optional.) A `tf.string` scalar containing the delimiter - character that separates fields in a record. Defaults to `","`. - use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats - double quotation marks as regular characters inside of string fields - (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. - na_value: (Optional.) A `tf.string` scalar indicating a value that will - be treated as NA/NaN. - select_cols: (Optional.) A sorted list of column indices to select from - the input data. If specified, only this subset of columns will be - parsed. Defaults to parsing all columns. - """ - super(CsvDataset, self).__init__() - self._filenames = ops.convert_to_tensor( - filenames, dtype=dtypes.string, name="filenames") - self._compression_type = convert.optional_param_to_tensor( - "compression_type", - compression_type, - argument_default="", - argument_dtype=dtypes.string) - record_defaults = [ - constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x - for x in record_defaults - ] - self._record_defaults = ops.convert_n_to_tensor( - record_defaults, name="record_defaults") - self._buffer_size = convert.optional_param_to_tensor( - "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) - self._header = ops.convert_to_tensor( - header, dtype=dtypes.bool, name="header") - self._field_delim = ops.convert_to_tensor( - field_delim, dtype=dtypes.string, name="field_delim") - self._use_quote_delim = ops.convert_to_tensor( - use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") - self._na_value = ops.convert_to_tensor( - na_value, dtype=dtypes.string, name="na_value") - self._select_cols = convert.optional_param_to_tensor( - "select_cols", - select_cols, - argument_default=[], - argument_dtype=dtypes.int64, - ) - self._output_shapes = tuple( - tensor_shape.scalar() for _ in range(len(record_defaults))) - self._output_types = tuple(d.dtype for d in self._record_defaults) - self._output_classes = tuple( - ops.Tensor for _ in range(len(record_defaults))) - - def _as_variant_tensor(self): - # Constructs graph node for the dataset op. - return contrib_gen_dataset_ops.csv_dataset( - filenames=self._filenames, - record_defaults=self._record_defaults, - buffer_size=self._buffer_size, - header=self._header, - output_shapes=self._output_shapes, - field_delim=self._field_delim, - use_quote_delim=self._use_quote_delim, - na_value=self._na_value, - select_cols=self._select_cols, - compression_type=self._compression_type, - ) - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_classes(self): - return self._output_classes + super(CsvDataset, self).__init__( + filenames, record_defaults, compression_type, buffer_size, header, + field_delim, use_quote_delim, na_value, select_cols) +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.") def make_batched_features_dataset(file_pattern, batch_size, features, @@ -666,7 +174,7 @@ def make_batched_features_dataset(file_pattern, shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=1, + prefetch_buffer_size=optimization.AUTOTUNE, reader_num_threads=1, parser_num_threads=2, sloppy_ordering=False, @@ -739,7 +247,7 @@ def make_batched_features_dataset(file_pattern, shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: Number of feature batches to prefetch in order to improve performance. Recommended value is the number of batches consumed - per training step (default is 1). + per training step. Defaults to auto-tune. reader_num_threads: Number of threads used to read `Example` records. If >1, the results will be interleaved. parser_num_threads: Number of threads to use for parsing `Example` tensors @@ -760,57 +268,15 @@ def make_batched_features_dataset(file_pattern, Raises: ValueError: If `label_key` is not one of the `features` keys. """ - # Create dataset of all matching filenames - filenames = _get_file_names(file_pattern, False) - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) - if shuffle: - dataset = dataset.shuffle(len(filenames), shuffle_seed) + return readers.make_batched_features_dataset( + file_pattern, batch_size, features, reader, label_key, reader_args, + num_epochs, shuffle, shuffle_buffer_size, shuffle_seed, + prefetch_buffer_size, reader_num_threads, parser_num_threads, + sloppy_ordering, drop_final_batch) - # Read `Example` records from files as tensor objects. - if reader_args is None: - reader_args = [] - # Read files sequentially (if reader_num_threads=1) or in parallel - dataset = dataset.apply( - interleave_ops.parallel_interleave( - lambda filename: reader(filename, *reader_args), - cycle_length=reader_num_threads, - sloppy=sloppy_ordering)) - - # Extract values if the `Example` tensors are stored as key-value tuples. - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset_ops.MapDataset( - dataset, lambda _, v: v, use_inter_op_parallelism=False) - - # Apply dataset repeat and shuffle transformations. - dataset = _maybe_shuffle_and_repeat( - dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to - # improve the shape inference, because it makes the batch dimension static. - # It is safe to do this because in that case we are repeating the input - # indefinitely, and all batches will be full-sized. - dataset = dataset.batch( - batch_size, drop_remainder=drop_final_batch or num_epochs is None) - - # Parse `Example` tensors to a dictionary of `Feature` tensors. - dataset = dataset.apply( - parsing_ops.parse_example_dataset( - features, num_parallel_calls=parser_num_threads)) - - if label_key: - if label_key not in features: - raise ValueError( - "The `label_key` provided (%r) must be one of the `features` keys." % - label_key) - dataset = dataset.map(lambda x: (x, x.pop(label_key))) - - dataset = dataset.prefetch(prefetch_buffer_size) - return dataset - - -@deprecation.deprecated(None, - "Use `tf.contrib.data.make_batched_features_dataset`") +@deprecation.deprecated( + None, "Use `tf.data.experimental.make_batched_features_dataset(...)`") def read_batch_features(file_pattern, batch_size, features, @@ -880,7 +346,7 @@ def read_batch_features(file_pattern, Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. """ - dataset = make_batched_features_dataset( + dataset = readers.make_batched_features_dataset( file_pattern, batch_size, features, @@ -894,99 +360,16 @@ def read_batch_features(file_pattern, return outputs -def _get_file_names(file_pattern, shuffle): - """Parse list of file names from pattern, optionally shuffled. - - Args: - file_pattern: File glob pattern, or list of glob patterns. - shuffle: Whether to shuffle the order of file names. - - Returns: - List of file names matching `file_pattern`. - - Raises: - ValueError: If `file_pattern` is empty, or pattern matches no files. - """ - if isinstance(file_pattern, list): - if not file_pattern: - raise ValueError("File pattern is empty.") - file_names = [] - for entry in file_pattern: - file_names.extend(gfile.Glob(entry)) - else: - file_names = list(gfile.Glob(file_pattern)) - - if not file_names: - raise ValueError("No files match %s." % file_pattern) - - # Sort files so it will be deterministic for unit tests. - if not shuffle: - file_names = sorted(file_names) - return file_names - - -class SqlDataset(dataset_ops.Dataset): +class SqlDataset(readers.SqlDataset): """A `Dataset` consisting of the results from a SQL query.""" + @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.") def __init__(self, driver_name, data_source_name, query, output_types): - """Creates a `SqlDataset`. - - `SqlDataset` allows a user to read data from the result set of a SQL query. - For example: - - ```python - dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3", - "SELECT name, age FROM people", - (tf.string, tf.int32)) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - # Prints the rows of the result set of the above query. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - ``` - - Args: - driver_name: A 0-D `tf.string` tensor containing the database type. - Currently, the only supported value is 'sqlite'. - data_source_name: A 0-D `tf.string` tensor containing a connection string - to connect to the database. - query: A 0-D `tf.string` tensor containing the SQL query to execute. - output_types: A tuple of `tf.DType` objects representing the types of the - columns returned by `query`. - """ - super(SqlDataset, self).__init__() - self._driver_name = ops.convert_to_tensor( - driver_name, dtype=dtypes.string, name="driver_name") - self._data_source_name = ops.convert_to_tensor( - data_source_name, dtype=dtypes.string, name="data_source_name") - self._query = ops.convert_to_tensor( - query, dtype=dtypes.string, name="query") - self._output_types = output_types - - def _as_variant_tensor(self): - return gen_dataset_ops.sql_dataset(self._driver_name, - self._data_source_name, self._query, - nest.flatten(self.output_types), - nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return nest.map_structure(lambda _: ops.Tensor, self._output_types) - - @property - def output_shapes(self): - return nest.map_structure(lambda _: tensor_shape.TensorShape([]), - self._output_types) - - @property - def output_types(self): - return self._output_types + super(SqlDataset, self).__init__( + driver_name, data_source_name, query, output_types) -class LMDBDataset(dataset_ops.Dataset): +class LMDBDataset(dataset_ops.DatasetSource): """A LMDB Dataset that reads the lmdb file.""" def __init__(self, filenames): @@ -1014,7 +397,7 @@ class LMDBDataset(dataset_ops.Dataset): filenames, dtype=dtypes.string, name="filenames") def _as_variant_tensor(self): - return contrib_gen_dataset_ops.lmdb_dataset( + return gen_experimental_dataset_ops.experimental_lmdb_dataset( self._filenames, output_types=nest.flatten(self.output_types), output_shapes=nest.flatten(self.output_shapes)) diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 75642f143e19c3d77e675384362c4dab94e10932..29d77528d95ba62783c1f7c1c0df530ed3929c9e 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -17,22 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import interleave_ops -from tensorflow.contrib.data.python.ops import scan_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.data.experimental.ops import resampling +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.rejection_resample(...)`.") def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): """A transformation that resamples a dataset to achieve a target distribution. @@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") - class_values_ds = dataset.map(class_func) - - # Get initial distribution. - if initial_dist is not None: - initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist, prob_of_original = ( - _calculate_acceptance_probs_with_mixing(initial_dist_t, - target_dist_t)) - initial_dist_ds = dataset_ops.Dataset.from_tensors( - initial_dist_t).repeat() - acceptance_dist_ds = dataset_ops.Dataset.from_tensors( - acceptance_dist).repeat() - prob_of_original_ds = dataset_ops.Dataset.from_tensors( - prob_of_original).repeat() - else: - initial_dist_ds = _estimate_initial_dist_ds( - target_dist_t, class_values_ds) - acceptance_and_original_prob_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs_with_mixing( - initial, target_dist_t)) - acceptance_dist_ds = acceptance_and_original_prob_ds.map( - lambda accept_prob, _: accept_prob) - prob_of_original_ds = acceptance_and_original_prob_ds.map( - lambda _, prob_original: prob_original) - filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, - class_values_ds, seed) - # Prefetch filtered dataset for speed. - filtered_ds = filtered_ds.prefetch(3) - - prob_original_static = _get_prob_original_static( - initial_dist_t, target_dist_t) if initial_dist is not None else None - if prob_original_static == 1: - return dataset_ops.Dataset.zip((class_values_ds, dataset)) - elif prob_original_static == 0: - return filtered_ds - else: - return interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], - weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), - seed=seed) - - return _apply_fn - - -def _get_prob_original_static(initial_dist_t, target_dist_t): - """Returns the static probability of sampling from the original. - - `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters - an Op that it isn't defined for. We have some custom logic to avoid this. - - Args: - initial_dist_t: A tensor of the initial distribution. - target_dist_t: A tensor of the target distribution. - - Returns: - The probability of sampling from the original distribution as a constant, - if it is a constant, or `None`. - """ - init_static = tensor_util.constant_value(initial_dist_t) - target_static = tensor_util.constant_value(target_dist_t) - - if init_static is None or target_static is None: - return None - else: - return np.min(target_static / init_static) - - -def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, - seed): - """Filters a dataset based on per-class acceptance probabilities. - - Args: - dataset: The dataset to be filtered. - acceptance_dist_ds: A dataset of acceptance probabilities. - initial_dist_ds: A dataset of the initial probability distribution, given or - estimated. - class_values_ds: A dataset of the corresponding classes. - seed: (Optional.) Python integer seed for the resampler. - - Returns: - A dataset of (class value, data) after filtering. - """ - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, - initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - def _gather_and_copy(class_val, acceptance_prob, data): - return class_val, array_ops.gather(acceptance_prob, class_val), data - - current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( - (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) - filtered_ds = ( - current_probabilities_and_class_and_data_ds - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return filtered_ds.map(lambda class_value, _, data: (class_value, data)) - - -def _estimate_initial_dist_ds( - target_dist_t, class_values_ds, dist_estimation_batch_size=32, - smoothing_constant=10): - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - initial_examples_per_class_seen = array_ops.fill( - [num_classes], np.int64(smoothing_constant)) - - def update_estimate_and_tile(num_examples_per_class_seen, c): - updated_examples_per_class_seen, dist = _estimate_data_distribution( - c, num_examples_per_class_seen) - tiled_dist = array_ops.tile( - array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) - return updated_examples_per_class_seen, tiled_dist - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .apply(scan_ops.scan(initial_examples_per_class_seen, - update_estimate_and_tile)) - .apply(batching.unbatch())) - - return initial_dist_ds - - -def _get_target_to_initial_ratio(initial_probs, target_probs): - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - return target_probs / denom - - -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. - - Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, - containing counts. - - Returns: - num_examples_per_lass_seen: Updated counts. Type `int64`, shape - `[num_classes]`. - dist: The updated distribution. Type `float32`, shape `[num_classes]`. - """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in batch. - num_examples_per_class_seen = math_ops.add( - num_examples_per_class_seen, math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - dist = math_ops.cast(init_prob_estimate, dtypes.float32) - return num_examples_per_class_seen, dist - - -def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): - """Calculates the acceptance probabilities and mixing ratio. - - In this case, we assume that we can *either* sample from the original data - distribution with probability `m`, or sample from a reshaped distribution - that comes from rejection sampling on the original distribution. This - rejection sampling is done on a per-class basis, with `a_i` representing the - probability of accepting data from class `i`. - - This method is based on solving the following analysis for the reshaped - distribution: - - Let F be the probability of a rejection (on any example). - Let p_i be the proportion of examples in the data in class i (init_probs) - Let a_i is the rate the rejection sampler should *accept* class i - Let t_i is the target proportion in the minibatches for class i (target_probs) - - ``` - F = sum_i(p_i * (1-a_i)) - = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 - ``` - - An example with class `i` will be accepted if `k` rejections occur, then an - example with class `i` is seen by the rejector, and it is accepted. This can - be written as follows: - - ``` - t_i = sum_k=0^inf(F^k * p_i * a_i) - = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 - = p_i * a_i / sum_j(p_j * a_j) using F from above - ``` - - Note that the following constraints hold: - ``` - 0 <= p_i <= 1, sum_i(p_i) = 1 - 0 <= a_i <= 1 - 0 <= t_i <= 1, sum_i(t_i) = 1 - ``` - - A solution for a_i in terms of the other variables is the following: - ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - - If we try to minimize the amount of data rejected, we get the following: - - M_max = max_i [ t_i / p_i ] - M_min = min_i [ t_i / p_i ] - - The desired probability of accepting data if it comes from class `i`: - - a_i = (t_i/p_i - m) / (M_max - m) - - The desired probability of pulling a data element from the original dataset, - rather than the filtered one: - - m = M_min - - Args: - initial_probs: A Tensor of the initial probability distribution, given or - estimated. - target_probs: A Tensor of the corresponding classes. - - Returns: - (A 1D Tensor with the per-class acceptance probabilities, the desired - probability of pull from the original distribution.) - """ - ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) - max_ratio = math_ops.reduce_max(ratio_l) - min_ratio = math_ops.reduce_min(ratio_l) - - # Target prob to sample from original distribution. - m = min_ratio - - # TODO(joelshor): Simplify fraction, if possible. - a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m + return resampling.rejection_resample(class_func, target_dist, initial_dist, + seed) diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 6b002b4a533669dd0f5e82a00aa29224a83a7e57..0ca9fddb23b20995bdcd4d45aa675537111c4552 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -17,137 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_dataset_ops - - -class _ScanDataset(dataset_ops.Dataset): - """A dataset that scans a function across its input.""" - - def __init__(self, input_dataset, initial_state, scan_func): - """See `scan()` for details.""" - super(_ScanDataset, self).__init__() - self._input_dataset = input_dataset - - with ops.name_scope("initial_state"): - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - self._initial_state = nest.pack_sequence_as(initial_state, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( - t, name="component_%d" % i) - for i, t in enumerate(nest.flatten(initial_state)) - ]) - - # Compute initial values for the state classes, shapes and types based on - # the initial state. The shapes may be refined by running `tf_scan_func` one - # or more times below. - self._state_classes = sparse.get_classes(self._initial_state) - self._state_shapes = nest.pack_sequence_as( - self._initial_state, - [t.get_shape() for t in nest.flatten(self._initial_state)]) - self._state_types = nest.pack_sequence_as( - self._initial_state, - [t.dtype for t in nest.flatten(self._initial_state)]) - - # Will be populated by calling `tf_scan_func`. - self._output_classes = None - self._output_shapes = None - self._output_types = None - - # Iteratively rerun the scan function until reaching a fixed point on - # `self._state_shapes`. - need_to_rerun = True - while need_to_rerun: - - wrapped_func = dataset_ops.StructuredFunctionWrapper( - scan_func, "tf.contrib.data.scan()", - input_classes=(self._state_classes, input_dataset.output_classes), - input_shapes=(self._state_shapes, input_dataset.output_shapes), - input_types=(self._state_types, input_dataset.output_types), - add_to_graph=False) - if not ( - isinstance(wrapped_func.output_types, collections.Sequence) and - len(wrapped_func.output_types) == 2): - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - new_state_classes, self._output_classes = wrapped_func.output_classes - - # Extract and validate class information from the returned values. - for new_state_class, state_class in zip( - nest.flatten(new_state_classes), - nest.flatten(self._state_classes)): - if not issubclass(new_state_class, state_class): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, new_state_classes)) - - # Extract and validate type information from the returned values. - new_state_types, self._output_types = wrapped_func.output_types - for new_state_type, state_type in zip( - nest.flatten(new_state_types), nest.flatten(self._state_types)): - if new_state_type != state_type: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, new_state_types)) - - # Extract shape information from the returned values. - new_state_shapes, self._output_shapes = wrapped_func.output_shapes - - flat_state_shapes = nest.flatten(self._state_shapes) - flat_new_state_shapes = nest.flatten(new_state_shapes) - weakened_state_shapes = [ - original.most_specific_compatible_shape(new) - for original, new in zip(flat_state_shapes, flat_new_state_shapes) - ] - - need_to_rerun = False - for original_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if original_shape.ndims is not None and ( - weakened_shape.ndims is None or - original_shape.as_list() != weakened_shape.as_list()): - need_to_rerun = True - break - - if need_to_rerun: - self._state_shapes = nest.pack_sequence_as(self._state_shapes, - weakened_state_shapes) - - self._scan_func = wrapped_func.function - self._scan_func.add_to_graph(ops.get_default_graph()) - - def _as_variant_tensor(self): - input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access - return gen_dataset_ops.scan_dataset( - input_t, - nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), - self._scan_func.captured_inputs, - f=self._scan_func, - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types +from tensorflow.python.data.experimental.ops import scan_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.") def scan(initial_state, scan_func): """A transformation that scans a function across an input dataset. @@ -168,7 +42,4 @@ def scan(initial_state, scan_func): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - def _apply_fn(dataset): - return _ScanDataset(dataset, initial_state, scan_func) - - return _apply_fn + return scan_ops.scan(initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 4356721704046199e8ef2938bde6d7d8bce68cc1..329b34fdfecf026688c3ebd210d3400a427940a8 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -17,59 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import random_seed -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops - - -class _ShuffleAndRepeatDataset(dataset_ops.Dataset): - """A `Dataset` that fuses `shuffle` and `repeat`.""" - - def __init__(self, - input_dataset, - buffer_size, - count=None, - seed=None): - """See `Dataset.map()` for details.""" - super(_ShuffleAndRepeatDataset, self).__init__() - self._input_dataset = input_dataset - self._buffer_size = ops.convert_to_tensor( - buffer_size, dtype=dtypes.int64, name="buffer_size") - if count is None: - self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") - else: - self._count = ops.convert_to_tensor( - count, dtype=dtypes.int64, name="count") - self._seed, self._seed2 = random_seed.get_seed(seed) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.shuffle_and_repeat_dataset( - input_resource, - buffer_size=self._buffer_size, - count=self._count, - seed=self._seed, - seed2=self._seed2, - **dataset_ops.flat_structure(self)) - # pylint: enable=protected-access - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types +from tensorflow.python.data.experimental.ops import shuffle_ops +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, + "Use `tf.data.experimental.shuffle_and_repeat(...)`.") def shuffle_and_repeat(buffer_size, count=None, seed=None): """Shuffles and repeats a Dataset returning a new permutation for each epoch. @@ -98,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): # pylint: disable=missing-docstring - return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) - - return _apply_fn + return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 8025dcdd16b0180aeb951a31de21e22b8e8c31c7..bcc383587c54bd89502313f9328bc06c49046a87 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation -class _SlideDataset(dataset_ops.Dataset): +class _SlideDataset(dataset_ops.UnaryDataset): """A `Dataset` that passes a sliding window over its input.""" def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" - super(_SlideDataset, self).__init__() + super(_SlideDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( window_size, dtype=dtypes.int64, name="window_stride") @@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset): @deprecation.deprecated_args( None, "stride is deprecated, use window_shift instead", "stride") +@deprecation.deprecated( + None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, " + "stride=window_stride).flat_map(lambda x: x.batch(window.size))` " + "instead.") def sliding_window_batch(window_size, stride=None, window_shift=None, diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index dc67accdcfbc2692cbe0c961521897a316f40647..20cceb4647ae6d5f80a9dbac3baed72d50254f09 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -17,89 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context -from tensorflow.python.ops import resource_variable_ops - -_uid_counter = 0 -_uid_lock = threading.Lock() - - -def _generate_shared_name(prefix): - with _uid_lock: - global _uid_counter - uid = _uid_counter - _uid_counter += 1 - return "{}{}".format(prefix, uid) - - -# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable -# or make private / remove. -class PrivateThreadPool(object): - """A stateful resource that represents a private thread pool.""" - - def __init__(self, num_threads, display_name=None, - max_intra_op_parallelism=1): - """Creates a `PrivateThreadPool` with the given number of threads.""" - if context.executing_eagerly(): - shared_name = _generate_shared_name("privatethreadpool") - self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, - max_intra_op_parallelism=max_intra_op_parallelism, - display_name=display_name, - shared_name=shared_name) - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=context.context().device_name) - else: - self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, - max_intra_op_parallelism=max_intra_op_parallelism, - display_name=display_name) - - -class _ThreadPoolDataset(dataset_ops.Dataset): - """A `Dataset` that acts as an identity, and sets a custom threadpool.""" - - def __init__(self, input_dataset, thread_pool): - super(_ThreadPoolDataset, self).__init__() - self._input_dataset = input_dataset - self._thread_pool = thread_pool - - def _as_variant_tensor(self): - return gen_dataset_ops.thread_pool_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._thread_pool._resource, # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_classes(self): - return self._input_dataset.output_classes - - -# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable -# or make private / remove. -def override_threadpool(dataset, thread_pool): - """Returns a new dataset that uses the given thread pool for its operations. - - Args: - dataset: A `tf.data.Dataset` object. - thread_pool: A `PrivateThreadPool` object. - - Returns: - A dataset containing the same values as `dataset`, but which uses - `thread_pool` to compute any of its parallel operations (such as - `tf.data.Dataset.map`). - """ - return _ThreadPoolDataset(dataset, thread_pool) +# pylint: disable=unused-import +from tensorflow.python.data.experimental.ops.threadpool import override_threadpool +from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index e0d606311c4f2f678970113c1faa578dbf44b2ba..909d06c677ea29733966e0c19a7543b149d2fe74 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes +from tensorflow.python.data.experimental.ops import unique as experimental_unique +from tensorflow.python.util import deprecation +@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.") def unique(): """Creates a `Dataset` from another `Dataset`, discarding duplicates. @@ -40,39 +39,4 @@ def unique(): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ - - def _apply_fn(dataset): - return _UniqueDataset(dataset) - - return _apply_fn - - -class _UniqueDataset(dataset_ops.Dataset): - """A `Dataset` contains the unique elements from its input.""" - - def __init__(self, input_dataset): - """See `unique()` for details.""" - super(_UniqueDataset, self).__init__() - self._input_dataset = input_dataset - if input_dataset.output_types not in (dtypes.int32, dtypes.int64, - dtypes.string): - raise TypeError( - "`tf.contrib.data.unique()` only supports inputs with a single " - "`tf.int32`, `tf.int64`, or `tf.string` component.") - - def _as_variant_tensor(self): - return gen_dataset_ops.unique_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - **dataset_ops.flat_structure(self)) - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types + return experimental_unique.unique() diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py index c455fdcba673853079ff0d162c4799e72bc8e627..42fb69bf077afbd2094f6eb1bf3fe7b17f761910 100644 --- a/tensorflow/contrib/data/python/ops/writers.py +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -17,42 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.data.experimental.ops import writers +from tensorflow.python.util import deprecation -class TFRecordWriter(object): +class TFRecordWriter(writers.TFRecordWriter): """Writes data to a TFRecord file.""" + @deprecation.deprecated( + None, "Use `tf.data.experimental.TFRecordWriter(...)`.") def __init__(self, filename, compression_type=None): - self._filename = ops.convert_to_tensor( - filename, dtypes.string, name="filename") - self._compression_type = convert.optional_param_to_tensor( - "compression_type", - compression_type, - argument_default="", - argument_dtype=dtypes.string) - - def write(self, dataset): - """Returns a `tf.Operation` to write a dataset to a file. - - Args: - dataset: a `tf.data.Dataset` whose elements are to be written to a file - - Returns: - A `tf.Operation` that, when run, writes contents of `dataset` to a file. - """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - if (dataset.output_types != dtypes.string or - dataset.output_shapes != tensor_shape.scalar()): - raise TypeError( - "`dataset` must produce scalar `DT_STRING` tensors whereas it " - "produces shape {0} and types {1}".format(dataset.output_shapes, - dataset.output_types)) - return gen_dataset_ops.dataset_to_tf_record( - dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access + super(TFRecordWriter, self).__init__(filename, compression_type) diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 3b50a48336d77ebd9327fa24e5612a95d5d0c372..06940a90d5cf5c193b026a20c3a5fa41e778b0a9 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -17,7 +17,6 @@ tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, - java_api_version = 2, visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/deprecated/summaries_test.py b/tensorflow/contrib/deprecated/summaries_test.py index 6acf2a6469c3cb27541721ddc5962e6879a88469..4038224a1c79a09a6f27c154be435f6dffd13d6c 100644 --- a/tensorflow/contrib/deprecated/summaries_test.py +++ b/tensorflow/contrib/deprecated/summaries_test.py @@ -27,31 +27,31 @@ from tensorflow.python.platform import test class DeprecatedSummariesTest(test.TestCase): def testScalarSummary(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(3) s = logging_ops.scalar_summary('tag', c) self.assertEqual(s.op.type, u'ScalarSummary') def testHistogramSummary(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(3) s = logging_ops.histogram_summary('tag', c) self.assertEqual(s.op.type, u'HistogramSummary') def testImageSummary(self): - with self.test_session(): + with self.cached_session(): i = array_ops.ones((5, 4, 4, 3)) s = logging_ops.image_summary('tag', i) self.assertEqual(s.op.type, u'ImageSummary') def testAudioSummary(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(3.0) s = logging_ops.audio_summary('tag', c, sample_rate=8000) self.assertEqual(s.op.type, u'AudioSummaryV2') def testMergeSummary(self): - with self.test_session(): + with self.cached_session(): c = constant_op.constant(3) a = logging_ops.scalar_summary('a', c) b = logging_ops.scalar_summary('b', c) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 30e1992c015d35859218d1b7fe3b2f3eb7c09b9b..2e025765e4aaab7114aa6e3e79336e48a71b5b55 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -76,7 +76,7 @@ We then compile the Keras model and pass the `MirroredStrategy` object in the ```python model.compile(loss='mean_squared_error', optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=strategy) + distribute=distribution) ``` To train the model we call Keras `fit` API using the input dataset that we @@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use important to shuffle your dataset in your `input_fn`. `MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you -`input_fn`. As a result, each worker gets a fraction of your input data. +`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker +gets a fraction of your input data. ### Performance Tips diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 350f81f60f84a74b7d2b9211dd92f6287cc8dc6d..823fe6a917f4f31ab6822e4bb1130d62ff45f0c9 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Prototype of a distributed computation library for TF.""" +"""A distributed computation library for TF. + +See [tensorflow/contrib/distribute/README.md]( +https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md) +for overview and examples. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 87f76eaa948e09f2a3b7fd0ba52a154824e9fe33..76d5b59ce17279b7c6d2d930504153fc31deb8e2 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -28,6 +28,7 @@ py_library( "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", @@ -410,6 +411,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "moving_averages_test", + srcs = ["moving_averages_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + tags = [ + "no_pip", + ], +) + cuda_py_test( name = "optimizer_v2_test", srcs = ["optimizer_v2_test.py"], @@ -453,7 +472,7 @@ cuda_py_test( cuda_py_test( name = "estimator_training_test", - size = "large", + size = "enormous", srcs = ["estimator_training_test.py"], additional_deps = [ ":combinations", @@ -472,11 +491,8 @@ cuda_py_test( "//tensorflow/python:summary", ], tags = [ - "manual", "multi_and_single_gpu", "no_pip", - "nogpu", - "notap", ], ) @@ -485,7 +501,6 @@ py_library( srcs = ["single_loss_example.py"], deps = [ ":step_fn", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", @@ -655,8 +670,8 @@ py_library( name = "prefetching_ops_v2", srcs = ["prefetching_ops_v2.py"], deps = [ - "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", @@ -731,6 +746,7 @@ cuda_py_test( additional_deps = [ ":keras_test_lib", ], + shard_count = 16, tags = [ "multi_and_single_gpu", "no_pip", @@ -739,18 +755,27 @@ cuda_py_test( ], ) -cuda_py_test( - name = "metrics_v1_test", +py_library( + name = "metrics_v1_test_lib", + testonly = 1, srcs = ["metrics_v1_test.py"], - additional_deps = [ + deps = [ ":combinations", - "@absl_py//absl/testing:parameterized", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":metrics_v1_test_lib", ], tags = [ "multi_and_single_gpu", diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 77079d0df9a94254384e75b98a0f6432189f05d8..9809204f8f107270b5a7b51e65e06afdae7d96b8 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): def _real_mirrored_creator(devices, *args, **kwargs): """Creates one MirroredVariable on the current worker.""" index = {} + unique_var_name = ops.get_default_graph().unique_name( + kwargs["name"], mark_as_used=False).rstrip("/") collective_instance_key = self._collective_keys.get_instance_key( - key_id=kwargs["name"]) + key_id=unique_var_name) if "initial_value" not in kwargs: raise ValueError("Initial value must be specified.") initial_value = kwargs["initial_value"] @@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) + if i == 0: + actual_var_name = v.name.split(":")[0] + assert unique_var_name == actual_var_name, "%r vs %r" % ( + unique_var_name, actual_var_name) assert not isinstance(v, values.DistributedVariable) index[d] = v return index @@ -210,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): """Configures the object. Args: - session_config: a @{tf.ConfigProto} + session_config: a `tf.ConfigProto` cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the cluster configurations. task_type: the current task type, such as "worker". @@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if not session_config or not self._cluster_spec: return - session_config.isolate_session_state = True - assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 36e976107309f51a1772c939ea329d55494f552a..6796a23d464d344554ae9654e0992e30df5ad213 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +35,14 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.python.training import training_util class CollectiveAllReduceStrategyTestBase( @@ -122,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list @@ -146,6 +153,56 @@ class CollectiveAllReduceStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_complex_model(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + + def model_fn(): + """Mnist model with synthetic input.""" + data_format = 'channels_last' + input_shape = [28, 28, 1] + l = keras.layers + max_pool = l.MaxPooling2D((2, 2), (2, 2), + padding='same', + data_format=data_format) + model = keras.Sequential([ + l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)), + l.Conv2D( + 32, + 5, + padding='same', + data_format=data_format, + activation=nn.relu), max_pool, + l.Conv2D( + 64, + 5, + padding='same', + data_format=data_format, + activation=nn.relu), max_pool, + l.Flatten(), + l.Dense(1024, activation=nn.relu), + l.Dropout(0.4), + l.Dense(10) + ]) + image = random_ops.random_uniform([2, 28, 28]) + label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) + logits = model(image, training=True) + loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) + optimizer = adam.AdamOptimizer(learning_rate=1e-4) + train_op = optimizer.minimize(loss, + training_util.get_or_create_global_step()) + return train_op + + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess: + with d.scope(): + train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + return True + def _test_variable_initialization(self, task_type, task_id, num_gpus): distribution, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -206,6 +263,14 @@ class DistributedCollectiveAllReduceStrategyTest( self._cluster_spec, num_gpus=num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testComplexModel(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -236,6 +301,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._cluster_spec, num_gpus=num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testComplexModel(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + class LocalCollectiveAllReduceStrategy( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -246,6 +319,12 @@ class LocalCollectiveAllReduceStrategy( return self._test_minimize_loss_graph(None, None, num_gpus) + def testComplexModel(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + self._test_complex_model(None, None, num_gpus) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 244d1fcec8ba481337afeede181c29d0552e3c44..63a163e76cdd99c73399c657cbe9bc3d010369d2 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent +from tensorflow.python.training import rmsprop from tensorflow.python.util import tf_inspect @@ -328,10 +329,10 @@ one_device_strategy = NamedDistribution( required_gpus=None) tpu_strategy = NamedDistribution( "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=5), + TPUClusterResolver(""), steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( + "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) # Note that we disable prefetching for testing since prefetching makes @@ -348,24 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution( required_gpus=2) -adam_optimizer_v1_fn = NamedObject( - "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) adagrad_optimizer_v1_fn = NamedObject( "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) -optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, - adagrad_optimizer_v1_fn] +adam_optimizer_v1_fn = NamedObject("AdamV1", + lambda: adam.AdamOptimizer(0.001, epsilon=1)) +rmsprop_optimizer_v1_fn = NamedObject( + "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001)) + +optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn] -adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) -optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn, - adagrad_optimizer_v2_fn] +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + +optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 24cb08fb48f832572da5ae2113e6c224557c6a81..9fc1b8895516f64a956accd9290e7bf42ccef330 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -221,9 +221,12 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads -# threading.Lock() cannot be pickled and therefore cannot be a field of -# CollectiveKeys. +# threading.Lock() and threading.local() cannot be pickled and therefore cannot +# be a field of CollectiveKeys. Right now _thread_local is not necessary to be +# an instance member of CollectiveKeys since we always create a new thread for +# each tower. _lock = threading.Lock() +_thread_local = threading.local() # TODO(yuefengz): use random key starts to avoid reusing keys? @@ -266,14 +269,12 @@ class CollectiveKeys(object): # For instance keys without ids self._instance_key_start = instance_key_start - self._thread_local = threading.local() - def _get_thread_local_object(self): # We make instance key without key ids thread local so that it will work # with MirroredStrategy and distribute coordinator. - if not hasattr(self._thread_local, 'instance_key'): - self._thread_local.instance_key = self._instance_key_start - return self._thread_local + if not hasattr(_thread_local, 'instance_key'): + _thread_local.instance_key = self._instance_key_start + return _thread_local def get_group_key(self, devices): """Returns a group key for the set of devices. diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 5348512016efc504f92e5a956d627698b93b209a..157618f72ff2ea6dde171e7edb62ccaf7e1de516 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -26,21 +26,12 @@ import tempfile import threading from absl.testing import parameterized import numpy as np -import six -_portpicker_import_error = None -try: - import portpicker # pylint: disable=g-import-not-at-top -except ImportError as _error: # pylint: disable=invalid-name - _portpicker_import_error = _error - portpicker = None - -# pylint: disable=g-import-not-at-top from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.optimizer_v2 import adagrad -from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_coordinator as dc from tensorflow.python.distribute import estimator_training as dc_training @@ -57,7 +48,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import server_lib BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -73,130 +63,38 @@ EVALUATOR = dc._TaskType.EVALUATOR WORKER = dc._TaskType.WORKER PS = dc._TaskType.PS -original_run_distribute_coordinator = dc.run_distribute_coordinator - - -# TODO(yuefengz): merge this method back to test_util. -def _create_local_cluster(num_workers, - num_ps, - has_eval=False, - protocol="grpc", - worker_config=None, - ps_config=None): - if _portpicker_import_error: - raise _portpicker_import_error # pylint: disable=raising-bad-type - worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] - ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] - - cluster_dict = { - "worker": ["localhost:%s" % port for port in worker_ports], - "ps": ["localhost:%s" % port for port in ps_ports] - } - if has_eval: - cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()] - - cs = server_lib.ClusterSpec(cluster_dict) - - workers = [ - server_lib.Server( - cs, - job_name="worker", - protocol=protocol, - task_index=ix, - config=worker_config, - start=True) for ix in range(num_workers) - ] - ps_servers = [ - server_lib.Server( - cs, - job_name="ps", - protocol=protocol, - task_index=ix, - config=ps_config, - start=True) for ix in range(num_ps) - ] - if has_eval: - evals = [ - server_lib.Server( - cs, - job_name="evaluator", - protocol=protocol, - task_index=0, - config=worker_config, - start=True) - ] - else: - evals = [] - - return workers, ps_servers, evals - - -def _create_in_process_cluster(num_workers, num_ps, has_eval=False): - """Create an in-process cluster that consists of only standard server.""" - # Leave some memory for cuda runtime. - if has_eval: - gpu_mem_frac = 0.7 / (num_workers + 1) - else: - gpu_mem_frac = 0.7 / num_workers - - worker_config = config_pb2.ConfigProto() - worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # Enable collective ops which has no impact on non-collective ops. - # TODO(yuefengz, tucker): removing this after we move the initialization of - # collective mgr to the session level. - worker_config.experimental.collective_group_leader = ( - "/job:worker/replica:0/task:0") - - ps_config = config_pb2.ConfigProto() - ps_config.device_count["GPU"] = 0 - - return _create_local_cluster( - num_workers, - num_ps=num_ps, - has_eval=has_eval, - worker_config=worker_config, - ps_config=ps_config, - protocol="grpc") - - -def _create_cluster_spec(has_chief=False, - num_workers=1, - num_ps=0, - has_eval=False): - if _portpicker_import_error: - raise _portpicker_import_error # pylint: disable=raising-bad-type - - cluster_spec = {} - if has_chief: - cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()] - if num_workers: - cluster_spec[WORKER] = [ - "localhost:%s" % portpicker.pick_unused_port() - for _ in range(num_workers) - ] - if num_ps: - cluster_spec[PS] = [ - "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps) - ] - if has_eval: - cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] - return cluster_spec +original_run_std_server = dc._run_std_server -def _bytes_to_str(maybe_bytes): - if isinstance(maybe_bytes, six.string_types): - return maybe_bytes - else: - return str(maybe_bytes, "utf-8") +class MockOsEnv(dict): + + def __init__(self, *args): + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self, key, default) + def __getitem__(self, key): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.__getitem__(self._thread_local.dict, key) + else: + return dict.__getitem__(self, key) -def _strip_protocol(target): - # cluster_spec expects "host:port" strings. - if "//" in target: - return target.split("//")[1] - else: - return target + def __setitem__(self, key, val): + if not hasattr(self._thread_local, "dict"): + self._thread_local.dict = dict() + if key == "TF_CONFIG": + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return dict.__setitem__(self, key, val) class DistributeCoordinatorIntegrationTest(test.TestCase, @@ -205,22 +103,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps, cls._evals = _create_in_process_cluster( + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) - cls._cluster_spec = { - "worker": [ - _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers - ], - "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps], - "evaluator": [ - _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals - ] - } def setUp(self): self._model_dir = tempfile.mkdtemp() - self._event = threading.Event() + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, "environ", + self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -391,43 +287,17 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_distribute, eval_distribute, remote_cluster=self._cluster_spec) self._inspect_train_and_eval_events(estimator) - def _mock_run_distribute_coordinator( - self, - worker_fn, - strategy, - eval_fn, - eval_strategy, - mode=dc.CoordinatorMode.STANDALONE_CLIENT, - cluster_spec=None, - session_config=None): - # Calls the origial `run_distribute_coordinator` method but gets task config - # from environment variables and then signals the caller. - task_type = None - task_id = None - if not cluster_spec: - cluster_spec = None - tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) - if not cluster_spec: - cluster_spec = tf_config.get("cluster", {}) - task_env = tf_config.get("task", {}) - if task_env: - task_type = task_env.get("type", task_type) - task_id = int(task_env.get("index", task_id)) - self._event.set() - original_run_distribute_coordinator( - worker_fn, - strategy, - eval_fn, - eval_strategy, - mode=mode, - cluster_spec=cluster_spec, - task_type=task_type, - task_id=task_id, - session_config=session_config) - - def _task_thread(self, train_distribute, eval_distribute): - with test.mock.patch.object(dc, "run_distribute_coordinator", - self._mock_run_distribute_coordinator): + def _mock_run_std_server(self, *args, **kwargs): + ret = original_run_std_server(*args, **kwargs) + # Wait for all std servers to be brought up in order to reduce the chance of + # remote sessions taking local ports that have been assigned to std servers. + self._barrier.wait() + return ret + + def _task_thread(self, train_distribute, eval_distribute, tf_config): + os.environ["TF_CONFIG"] = json.dumps(tf_config) + with test.mock.patch.object(dc, "_run_std_server", + self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) def _run_task_in_thread(self, cluster_spec, task_type, task_id, @@ -448,13 +318,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, "index": task_id } } - self._event.clear() t = threading.Thread( - target=self._task_thread, args=(train_distribute, eval_distribute)) - with test.mock.patch.dict("os.environ", - {"TF_CONFIG": json.dumps(tf_config)}): - t.start() - self._event.wait() + target=self._task_thread, + args=(train_distribute, eval_distribute, tf_config)) + t.start() return t def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, @@ -489,7 +356,11 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, else: eval_distribute = None - cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=2, has_eval=True) + # 3 workers, 2 ps and 1 evaluator. + self._barrier = dc._Barrier(6) + threads = self._run_multiple_tasks_in_threads( cluster_spec, train_distribute, eval_distribute) for task_type, ts in threads.items(): @@ -516,7 +387,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, else: eval_distribute = None - cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=0, has_eval=True) + # 3 workers and 1 evaluator. + self._barrier = dc._Barrier(4) threads = self._run_multiple_tasks_in_threads( cluster_spec, train_distribute, eval_distribute) threads[WORKER][0].join() diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a84ef041960e389c08246fc8a16df2300856d968..da7f8c548f94972b6ec0a67848e1520386d1e28b 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -113,7 +113,7 @@ def main(_): distribute=strategy) # Train the model with the train dataset. - model.fit(x=train_ds, epochs=20, steps_per_epoch=310) + model.fit(x=train_ds, epochs=20, steps_per_epoch=468) # Evaluate the model with the eval dataset. score = model.evaluate(eval_ds, steps=10, verbose=0) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 44a69ed23a4e00ab81d5b51ae0c14550bd493f14..79a9803d75a35445280c006fa023637c9b01fdcc 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -22,6 +22,8 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.keras import metrics as metrics_module + def build_model_fn_optimizer(): """Simple model_fn with optimizer.""" @@ -45,7 +47,10 @@ def build_model_fn_optimizer(): return y * y if mode == tf.estimator.ModeKeys.EVAL: - return tf.estimator.EstimatorSpec(mode, loss=loss_fn()) + acc_obj = metrics_module.BinaryAccuracy() + acc_obj.update_state(labels, labels) + return tf.estimator.EstimatorSpec( + mode, loss=loss_fn(), eval_metric_ops={"Accuracy": acc_obj}) assert mode == tf.estimator.ModeKeys.TRAIN @@ -61,18 +66,26 @@ def main(_): ["/device:GPU:0", "/device:GPU:1"]) config = tf.estimator.RunConfig(train_distribute=distribution, eval_distribute=distribution) + # Since there are 2 devices and 10 samples, we set steps=5. + steps = 5 - def input_fn(): + def train_input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) labels = tf.data.Dataset.from_tensors([1.]).repeat(10) return tf.data.Dataset.zip((features, labels)) estimator = tf.estimator.Estimator( model_fn=build_model_fn_optimizer(), config=config) - estimator.train(input_fn=input_fn, steps=10) + estimator.train(input_fn=train_input_fn, steps=steps) + + def eval_input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + labels = tf.data.Dataset.from_tensors([1.]).repeat(10) + return tf.data.Dataset.zip((features, labels)) - eval_result = estimator.evaluate(input_fn=input_fn, steps=10) + eval_result = estimator.evaluate(input_fn=eval_input_fn, steps=steps) print("Eval result: {}".format(eval_result)) + assert eval_result["Accuracy"] == 1.0 def predict_input_fn(): predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py index c5acb7ced4bcb58cf327398f04fb37675a944e97..559de97bb1f93f990ddaf775d9203d5a2d46aa99 100644 --- a/tensorflow/contrib/distribute/python/input_ops_test.py +++ b/tensorflow/contrib/distribute/python/input_ops_test.py @@ -20,8 +20,6 @@ from __future__ import print_function import os -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.distribute.python import input_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers @@ -126,20 +124,6 @@ class AutoShardDatasetTest(test.TestCase): # contain records in order of files. self._verifySimpleShardingOutput(dataset, self._record) - def testParallelInterleave(self): - dataset = dataset_ops.Dataset.from_tensor_slices( - self._createTFRecordFiles()) - dataset = dataset.apply(interleave_ops.parallel_interleave( - readers.TFRecordDataset, - cycle_length=4, - block_length=self._num_records)) - dataset = input_ops.auto_shard_dataset( - dataset, self._num_shards, self._shard_index) - - # Since block_length == num records in each file, the output will still - # contain records in order of files. - self._verifySimpleShardingOutput(dataset, self._record) - def testListfiles(self): filenames = self._createTFRecordFiles() file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt" @@ -171,8 +155,8 @@ class AutoShardDatasetTest(test.TestCase): dataset = dataset.prefetch(buffer_size=batch_size) dataset = dataset.shuffle(2 * self._num_files * self._num_records) dataset = dataset.repeat(num_epochs) - dataset = dataset.apply(batching.map_and_batch( - lambda x: x, batch_size=batch_size)) + dataset = dataset.map(lambda x: x) + dataset = dataset.batch(batch_size) dataset = dataset.prefetch(buffer_size=None) # Auto shard. diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 5f35e381899a03f12cf0a6ed0168b9e500d41801..6553642ad320e40195a87420f3c3e51f439b8a8f 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -173,13 +173,50 @@ def batch_wrapper(dataset, batch_size, distribution): return dataset.batch(batch_size) -def all_combinations(): +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model + + +def get_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +def get_predict_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +strategies = [combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy_one_step] + + +def strategy_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.tpu_strategy_one_step], + distribution=strategies, + mode=['graph']) + + +def strategy_and_optimizer_combinations(): + return combinations.combine( + distribution=strategies, + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn], mode=['graph']) @@ -205,6 +242,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): keras_model = simple_functional_model() keras_model.compile( loss='categorical_crossentropy', + metrics=[keras.metrics.CategoricalAccuracy()], optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, @@ -229,6 +267,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', + metrics=[keras.metrics.CategoricalAccuracy()], optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, @@ -316,58 +355,27 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._config.model_dir) -class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): - - def test_validating_dataset_input_tensors_with_shape_mismatch(self): - with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) +class TestDistributionStrategyWithNumpyArrays(test.TestCase, + parameterized.TestCase): - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + @combinations.generate(strategy_combinations()) + def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) - - def test_calling_model_with_numpy_arrays(self): + x = np.asarray(np.random.random((64, 3)), dtype=np.float32) + var_x = distributed_training_utils.get_var_for_numpy(distribution, x) + val = self.evaluate(var_x.value()) + # Verify that the numpy value is copied to the variable. + self.assertAllEqual(x, val) + + @combinations.generate(strategy_combinations()) + def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) @@ -390,30 +398,70 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) + def test_calling_model_with_nested_numpy_arrays(self, distribution): + with self.cached_session(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) + output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) + targets = [output_d_np, output_e_np] + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + +class TestDistributionStrategyWithDatasets(test.TestCase, + parameterized.TestCase): + + @combinations.generate(strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', keras.metrics.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 10, distribution) + dataset = get_dataset(distribution) # Call fit with validation data model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) - model.predict(dataset, steps=2) + model.predict(get_predict_dataset(distribution), steps=2) # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not @@ -432,7 +480,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', keras.metrics.CategoricalAccuracy()] strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', '/device:CPU:0']) model.compile(optimizer, loss, metrics=metrics, distribute=strategy) @@ -459,62 +507,167 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', keras.metrics.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 10, distribution) + dataset = get_dataset(distribution) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) - # Test with validation data - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - validation_data=dataset, validation_steps=2) + model.predict(get_predict_dataset(distribution), steps=2) - def test_raise_error_for_stateful_metrics(self): + @combinations.generate(strategy_and_optimizer_combinations()) + def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): + with self.cached_session(): + model = get_model() - class ExampleStatefulMetric(keras.layers.Layer): + loss = 'mse' + model.compile(optimizer(), loss, distribute=distribution) - def __init__(self, name='true_positives', **kwargs): - super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) - self.stateful = True + dataset = get_dataset(distribution) - def __call__(self, y_true, y_pred): - return y_pred - y_true + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(get_predict_dataset(distribution), steps=2) + def test_dataset_input_shape_validation(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae', ExampleStatefulMetric()] strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', '/device:GPU:0']) - with self.assertRaisesRegexp( - NotImplementedError, 'Stateful metrics are not supported with ' - 'DistributionStrategy.'): - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + @combinations.generate(combinations.combine( + distribution=[combinations.tpu_strategy_one_step], + mode=['graph'])) + def test_dataset_input_shape_fully_defined(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + # Input shapes are not fully known. Batch dimension is unknown as we are + # not using the drop_remainder argument. + dataset = dataset.repeat(100).batch(10) + + with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. + with self.cached_session(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.ones((10, 1), dtype=np.float32) + targets = np.ones((10, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat().batch(8) + hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) + self.assertAlmostEqual(hist.history['acc'][0], 0, 0) + + model.set_weights(initial_weights) + evaluate_output = model.evaluate(dataset, steps=20) + self.assertAlmostEqual(evaluate_output[1], 1, 0) + + inputs = np.ones((10, 1), dtype=np.float32) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + predict_dataset = predict_dataset.repeat().batch(5) + output = model.predict(predict_dataset, steps=10) + ref_output = np.ones((50, 1), dtype=np.float32) + self.assertArrayNear(output[0], ref_output, 1e-1) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.cached_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.cached_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) def test_unsupported_features(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' @@ -524,11 +677,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = get_dataset(strategy) # Test with validation split with self.assertRaisesRegexp( @@ -565,9 +714,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_calling_with_unsupported_predefined_callbacks(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' @@ -576,11 +723,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): '/device:GPU:0']) model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = get_dataset(strategy) def schedule(_): return 0.001 @@ -602,74 +745,8 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) - def test_dataset_input_shape_validation(self): - with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, distribute=strategy) - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - - # Wrong input shape - inputs = np.zeros((10, 5), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - - with self.assertRaisesRegexp(ValueError, - 'expected input to have shape'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - - def test_learning_phase_value(self): - # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare - # meaningful values. Currently we don't pass the learning phase if the - # Lambda layer uses the learning phase. - with self.cached_session(): - x = keras.layers.Input(shape=(16,), name='input') - y = keras.layers.Dense(16)(x) - z = keras.layers.Dropout(0.9999)(y) - model = keras.Model(x, z) - - optimizer = gradient_descent.GradientDescentOptimizer(0.005) - loss = 'mse' - metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - - inputs = np.random.rand(10, 16) - targets = np.ones((10, 16), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(8) - - hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) - self.assertEqual(hist.history['acc'][0], 1) - - evaluate_output = model.evaluate(dataset, steps=20) - self.assertEqual(evaluate_output[1], 0) - - predict_output = model.predict(dataset, steps=1) - self.assertNotEqual(np.mean(predict_output), 0) - - -class LossMaskingWithDistributionStrategyTest(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. @@ -696,10 +773,10 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): self.assertEqual(hist.history['loss'][0], 0) -class NormalizationLayerWithDistributionStrategyTest( +class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -716,30 +793,72 @@ class NormalizationLayerWithDistributionStrategyTest( dataset = dataset.repeat(100) dataset = batch_wrapper(dataset, 32, distribution) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = batch_wrapper(predict_dataset, 32, distribution) + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) - out = model.predict(dataset, steps=2) + out = model.predict(predict_dataset, steps=2) out -= keras.backend.eval(norm.beta) out /= keras.backend.eval(norm.gamma) np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class CorrectnessWithDistributionStrategyTest(test.TestCase, - parameterized.TestCase): +class TestDistributionStrategyCorrectness(test.TestCase, + parameterized.TestCase): - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) + def test_metric_correctness(self, distribution): + with self.cached_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + # Create identity model. + model = keras.Sequential() + model.add( + keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()], + distribute=distribution) + + batch_size = 64 + batch_size //= distribution.num_towers + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = batch_wrapper(train_dataset, batch_size, distribution) + + history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0]) + + @combinations.generate(strategy_combinations()) def test_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') num_samples = 10000 + + # Train and predict datasets are created with the same input numpy arrays. x_train = np.random.rand(num_samples, 1) y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + initial_weights = model.get_weights() + def fit_and_predict(with_distribution=None): - model = keras.Sequential() - model.add(keras.layers.Dense(1, input_shape=(1,))) + model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), @@ -751,18 +870,19 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase, train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - # Running only 100 steps instead of the full dataset to keep test - # duration small. - model.fit(x=train_dataset, epochs=1, steps_per_epoch=100) + # We have initialized the model to the same weight for the distribution + # and non-distribution run. If you want to initialize the model to + # random weights for each run, you need to run the model through the + # entire dataset at least once to ensure that the weights converge to + # the same value. + model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) weights = model.get_weights() - x_predict = [[1.], [2.], [3.], [4.]] predict_batch_size = 4 if with_distribution: predict_batch_size //= with_distribution.num_towers - predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict, - x_predict)) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, distribution) predict_result = model.predict(predict_dataset, steps=1) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 8163494c8ed2c5c2164df2e731d09ebb794414cd..ae4189eb1cb217f8a209b57f91a0ddb82e63dcd9 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -35,7 +36,8 @@ def _labeled_dataset_fn(): # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True return dataset_ops.Dataset.range(1000).map( - lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + lambda x: {"labels": x % 5, "predictions": x % 3}).batch( + 4, drop_remainder=True) def _boolean_dataset_fn(): @@ -47,7 +49,8 @@ def _boolean_dataset_fn(): # F, T -> FP; T, F -> FN; F, F -> TN return dataset_ops.Dataset.from_tensor_slices({ "labels": [True, False, True, False], - "predictions": [True, True, False, False]}).repeat().batch(3) + "predictions": [True, True, False, False]}).repeat().batch( + 3, drop_remainder=True) def _threshold_dataset_fn(): @@ -59,7 +62,8 @@ def _threshold_dataset_fn(): # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN return dataset_ops.Dataset.from_tensor_slices({ "labels": [True, False, True, False], - "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch( + 3, drop_remainder=True) def _regression_dataset_fn(): @@ -79,6 +83,12 @@ def all_combinations(): mode=["graph"]) +def tpu_combinations(): + return combinations.combine(distribution=[combinations.tpu_strategy_one_step, + combinations.tpu_strategy], + mode=["graph"]) + + # TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, # metrics.precision_at_k class MetricsV1Test(test.TestCase, parameterized.TestCase): @@ -87,42 +97,50 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): with ops.Graph().as_default(), distribution.scope(): iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() - value, update = distribution.call_for_each_tower( - metric_fn, iterator.get_next()) - update = distribution.group(update) + if isinstance(distribution, tpu_strategy.TPUStrategy): + def step_fn(ctx, inputs): + value, update = distribution.call_for_each_tower( + metric_fn, inputs) + ctx.set_non_tensor_output(name="value", output=value) + return distribution.group(update) + + ctx = distribution.run_steps_on_dataset( + step_fn, iterator, iterations=distribution.steps_per_run) + update = ctx.run_op + value = ctx.non_tensor_outputs["value"] + # In each run, we run multiple steps, and each steps consumes as many + # batches as number of towers. + batches_per_update = ( + distribution.num_towers * distribution.steps_per_run) + else: + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + self.evaluate(distribution.initialize()) self.evaluate(variables.local_variables_initializer()) - # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_towers" with "1". - batches_per_update = distribution.num_towers - - # Update variables using the first `num_towers` batches. - self.evaluate(update) - self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), - 0.001, msg="After first update") - - # Update variables using the second `num_towers` batches. - self.evaluate(update) - self.assertAllClose(expected_fn(2 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After second update") - - if batches_per_update == 1: # Consume 4 input batches - self.evaluate(update) - self.assertAllClose(expected_fn(3 * batches_per_update), - self.evaluate(value), - 0.001, - msg="After third update") + + batches_consumed = 0 + for i in range(4): self.evaluate(update) - self.assertAllClose(expected_fn(4 * batches_per_update), + batches_consumed += batches_per_update + self.assertAllClose(expected_fn(batches_consumed), self.evaluate(value), 0.001, - msg="After fourth update") + msg="After update #" + str(i+1)) + if batches_consumed >= 4: # Consume 4 input batches in total. + break - @combinations.generate(all_combinations()) + self.evaluate(distribution.finalize()) + + @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): - return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch( + 4, drop_remainder=True) def _expected_fn(num_batches): # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. @@ -130,7 +148,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAccuracy(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -143,6 +161,8 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added + # for TPUMirroredVariable. @combinations.generate(all_combinations()) def testMeanPerClassAccuracy(self, distribution): def _metric_fn(x): @@ -161,6 +181,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + # NOTE(priyag): This metric doesn't work on TPUs yet. @combinations.generate(all_combinations()) def testMeanIOU(self, distribution): def _metric_fn(x): @@ -179,7 +200,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testMeanTensor(self, distribution): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) @@ -198,7 +219,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAUCROC(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -212,7 +233,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testAUCPR(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -226,7 +247,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalseNegatives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -239,7 +260,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalseNegativesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -252,7 +273,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTrueNegatives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -265,7 +286,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTrueNegativesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -278,7 +299,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalsePositives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -291,7 +312,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testFalsePositivesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -304,7 +325,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTruePositives(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -317,7 +338,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testTruePositivesAtThresholds(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -330,7 +351,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testPrecision(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -343,7 +364,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testPrecisionAtThreshold(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -356,7 +377,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRecall(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -369,7 +390,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRecallAtThreshold(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -382,7 +403,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testMeanSquaredError(self, distribution): def _metric_fn(x): labels = x["labels"] @@ -395,7 +416,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): self._test_metric( distribution, _regression_dataset_fn, _metric_fn, _expected_fn) - @combinations.generate(all_combinations()) + @combinations.generate(all_combinations() + tpu_combinations()) def testRootMeanSquaredError(self, distribution): def _metric_fn(x): labels = x["labels"] diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index ba147e78241e5ab45809e498e00debd45a2c49b4..60e134055ff3bd65ce717c5eb48168b07de4515c 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -179,11 +179,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], - "Adam": [ - "dense/kernel", "dense/bias", "beta1_power", "beta2_power", - "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", - "dense/bias/Adam_1" - ], "Adagrad": [ "dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad", "dense/bias" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 0c6805d68218029abcad784b476b76bf3d368a9f..0f82508428a58fb671cef25c97ca5880ebb38e83 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -293,7 +293,8 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) for v in index.values(): - l.remove(v) + if v in l: + l.remove(v) g.add_to_collections(collections, result) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) @@ -318,12 +319,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: - * **In-graph replication**: the `client` creates a single `tf.Graph` that + + * **In-graph replication**: the `client` creates a single `tf.Graph` that specifies tasks for devices on all workers. The `client` then creates a client session which will talk to the `master` service of a `worker`. Then the `master` will partition the graph and distribute the work to all participating workers. - * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one physical machine. We will have multiple `worker`s with different `task` index. They all do similar things except for one worker checkpointing model variables, writing summaries, etc. in addition to its ordinary work. @@ -347,6 +349,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input data to devices. + auto_shard_dataset: whether to auto-shard the dataset when there are + multiple workers. """ def __init__(self, @@ -354,11 +358,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_tower_ops=None, - prefetch_on_device=None): + prefetch_on_device=None, + auto_shard_dataset=False): super(MirroredStrategy, self).__init__() self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + self._auto_shard_dataset = auto_shard_dataset # Rememeber num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( @@ -456,16 +462,20 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) - else: - def initial_value_fn(device=d): + def initial_value_fn(device=d): + if context.executing_eagerly(): + init_value = index[devices[0]].value() + return array_ops.identity(init_value) + else: with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) - kwargs["initial_value"] = initial_value_fn + init_value = index[devices[0]].initial_value + return array_ops.identity(init_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) + # Don't record operations (e.g. other variable reads) during + # variable creation. + with tape.stop_recording(): + v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v return index @@ -477,7 +487,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if self._cluster_spec: return values.MultiWorkerDataset( partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) + self._prefetch_on_device, self._auto_shard_dataset) else: return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), self._devices, @@ -623,9 +633,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. assert isinstance(var, values.DistributedVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) @@ -634,10 +646,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): assert isinstance(colocate_with, list) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. # TODO(josh11b): In eager mode, use one thread per device. updates = {} for d in colocate_with: @@ -645,7 +659,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): updates[d] = fn(*values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.regroup(updates, values.Mirrored) + return values.update_regroup(self, updates, should_group) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index c6894e901326ec0e1d9b60ff736134372ee0494a..ed36639ce86e891544edb644150c5d31fe610b4f 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import sys +import numpy as np + from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib @@ -34,7 +36,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl @@ -43,6 +48,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -826,7 +833,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with dist.scope(): ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) - update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) # Initialize variables. self.evaluate(variables.global_variables_initializer()) @@ -1245,6 +1252,22 @@ class MockModel(object): return x +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. + """ + + def __init__(self): + super(MiniModel, self).__init__(name="") + self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones", + bias_initializer="ones") + + def call(self, inputs, training=True): + inputs = array_ops.ones([1, 10]) + return self.fc(inputs) + + class MirroredStrategyDefunTest(test.TestCase): def _skip_eager_if_gpus_less_than(self, num_gpus): @@ -1271,7 +1294,17 @@ class MirroredStrategyDefunTest(test.TestCase): self.evaluate(device_result)) for defun in defuns: - self.assertEqual(set(mock_model.variables), set(defun.variables)) + # PolymorphicFunctions are specialized to the current device stack, so + # call_for_each has one trace per device. To check that the expected set + # of variables was accessed on each trace, we first retrieve each + # device-specific graph function. + per_device_graph_functions = dist.call_for_each_tower( + defun.get_concrete_function, + mock_model, *inputs, run_concurrently=False) + for device in devices: + graph_function = per_device_graph_functions.get(device=device) + self.assertEqual(set(mock_model.variables), + set(graph_function.graph.variables)) @test_util.run_in_graph_and_eager_modes() def testVariableInDefun(self): @@ -1355,6 +1388,41 @@ class MirroredStrategyDefunTest(test.TestCase): "GPU:0": 3.0 * 1.25}) self._call_and_check(fn1, [factors], expected_result, [fn1]) + @test_util.run_in_graph_and_eager_modes() + def testTrain(self): + self._skip_eager_if_gpus_less_than(1) + + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MiniModel() + mock_model.call = function.defun(mock_model.call) + + def loss_fn(ctx): + del ctx + return mock_model(array_ops.ones([1, 10])) + + gradients_fn = backprop.implicit_grad(loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + grads_and_vars = dist.call_for_each_tower( + gradients_fn, None, run_concurrently=False) + + optimizer = gradient_descent.GradientDescentOptimizer(0.25) + update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(update_ops) + + updated_var_values = self.evaluate(mock_model.variables) + # All variables start at 1.0 and get two updates of 0.25. + self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0]) + self.assertAllEqual([0.5], updated_var_values[1]) + + class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py new file mode 100644 index 0000000000000000000000000000000000000000..119352ad9195dc51201863f34aef19cb3289e635 --- /dev/null +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training.moving_averages when using a DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages + + +all_combinations = combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu], + mode=["graph"]) + + +class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): + + @combinations.generate(all_combinations) + def testTowerModeWithoutZeroDebias(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + # Mean of val across calls to tower_fn(). + average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testTowerMode(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([0.0, 0.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average(var, val, decay) + return var, assign.op + + with distribution.scope(), self.cached_session() as sess: + var, assign_op = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(distribution.unwrap(assign_op)) + # Mean of val across calls to tower_fn(). + average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + self.assertAllClose(average_val, var.eval()) + + @combinations.generate(all_combinations) + def testCrossTowerWithoutZeroDebias(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0, 2.0]) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(assign) + average_val = [1.0, 2.0] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + # Also try assign.op. + sess.run(assign.op) + orig_weight = 0.25 * 0.25 + val_weight = 1.0 - orig_weight + self.assertAllClose( + [10.0 * orig_weight + average_val[0] * val_weight, + 11.0 * orig_weight + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testCrossTower(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([0.0, 0.0]) + val = array_ops.placeholder(dtypes.float32) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average(var, val, decay) + + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(assign, feed_dict={val: [1.0, 2.0]}) + self.assertAllClose([1.0, 2.0], var.eval()) + + # Also try assign.op. + sess.run(assign.op, feed_dict={val: [10.0, 0.0]}) + self.assertAllClose( + [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0), + (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], + var.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 18b4503eff4c7e83e8b98a6d71893dee15c19898..9f92ba7dde5fc2798201cef2238bcc4b20b698a8 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -36,9 +36,29 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +ASSIGNED_PORTS = set() +lock = threading.Lock() + + +def pick_unused_port(): + """Returns an unused and unassigned local port.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + global ASSIGNED_PORTS + with lock: + while True: + port = portpicker.pick_unused_port() + if port > 10000 and port not in ASSIGNED_PORTS: + ASSIGNED_PORTS.add(port) + logging.info('Using local port %r', port) + return port + + def _create_cluster(num_workers, num_ps, has_chief=False, @@ -49,8 +69,8 @@ def _create_cluster(num_workers, """Creates and starts local servers and returns the cluster_spec dict.""" if _portpicker_import_error: raise _portpicker_import_error # pylint: disable=raising-bad-type - worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] - ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + worker_ports = [pick_unused_port() for _ in range(num_workers)] + ps_ports = [pick_unused_port() for _ in range(num_ps)] cluster_dict = {} if num_workers > 0: @@ -58,9 +78,9 @@ def _create_cluster(num_workers, if num_ps > 0: cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports] if has_eval: - cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()] + cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()] if has_chief: - cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()] + cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()] cs = server_lib.ClusterSpec(cluster_dict) @@ -139,11 +159,36 @@ def create_in_process_cluster(num_workers, num_workers, num_ps=num_ps, has_chief=has_chief, + has_eval=has_eval, worker_config=worker_config, ps_config=ps_config, protocol='grpc') +def create_cluster_spec(has_chief=False, + num_workers=1, + num_ps=0, + has_eval=False): + """Create a cluster spec with tasks with unused local ports.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + cluster_spec = {} + if has_chief: + cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()] + if num_workers: + cluster_spec['worker'] = [ + 'localhost:%s' % pick_unused_port() for _ in range(num_workers) + ] + if num_ps: + cluster_spec['ps'] = [ + 'localhost:%s' % pick_unused_port() for _ in range(num_ps) + ] + if has_eval: + cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()] + return cluster_spec + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 23b220f64b843a83aba3f9867b61415b70f19668..f5259190485e701c190beb49220caff743f8fdcb 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): else: assert False - def _update(self, var, fn, *args, **kwargs): - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(var, *args, **kwargs) + def _update(self, var, options, fn, *args, **kwargs): + # The implementations of _update() and _update_non_slot() are identical + # except _update() passes `var` as the first argument to `fn()`. + return self._update_non_slot(var, options, fn, var, *args, **kwargs) - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): del colocate_with + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.device(self._device), distribute_lib.UpdateContext(self._device): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 1125d027f64420863386d4fbd9db5564a5847825..6ddd91507bf86e8b0cf710a2340fd61abcdebe71 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, fn, *args, **kwargs): + def _update(self, var, options, fn, *args, **kwargs): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - return fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) + result = fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + if should_group: + return result + else: + return nest.map_structure(self._unwrap, result) def _unwrap(self, val): if isinstance(val, values.DistributedValues): diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 12789e0bc9f1c89ef8d57c40a978e2bb9471997b..9c112e4f851b5e5e6f65c0bd9d9564420f8d4446 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -262,7 +262,9 @@ class ParameterServerStrategyTestBase( h = f + 1.0 self.assertEqual( device_util.canonicalize(u.device), tower_variable_device) - self.assertEqual(device_util.canonicalize(x.device), h.device) + self.assertEqual( + device_util.canonicalize(x.device), + device_util.canonicalize(h.device)) return y_add, z_add, f y, z, f = d.call_for_each_tower(model_fn) @@ -395,7 +397,8 @@ class ParameterServerStrategyTestBase( # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies( + d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 1ff60c076226299a89060a295c1cc0c50817b861..d48aa9c89bc894a6afc4aab8b60fabc52a06b198 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -19,9 +19,7 @@ from __future__ import print_function import warnings -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest as data_nest @@ -30,6 +28,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import nest @@ -42,10 +41,9 @@ class _PrefetchToDeviceIterator(object): one_shot: If true, we make a one shot iterator that's already initialized. devices: Devices on which to prefetch. buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be - shared under the given name across multiple sessions that share the - same devices (e.g. when using a remote server). Only used if one_shot - is False. + shared_name: (Optional.) If non-empty, the returned iterator will be shared + under the given name across multiple sessions that share the same devices + (e.g. when using a remote server). Only used if one_shot is False. Returns: An Iterator type object. @@ -82,7 +80,7 @@ class _PrefetchToDeviceIterator(object): ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) - target_device = gen_dataset_ops.iterator_get_device( + target_device = ged_ops.experimental_iterator_get_device( self._input_iterator._iterator_resource) self._buffering_resources = [] for device in nest.flatten(self._devices): @@ -102,7 +100,8 @@ class _PrefetchToDeviceIterator(object): reset_ops = [] for buffer_resource in self._buffering_resources: reset_ops.append( - prefetching_ops.function_buffering_resource_reset(buffer_resource)) + ged_ops.experimental_function_buffering_resource_reset( + buffer_resource)) with ops.control_dependencies(reset_ops): self._initializer = self._input_iterator.make_initializer( self._input_dataset) @@ -118,10 +117,11 @@ class _PrefetchToDeviceIterator(object): # batches) is not divisible by number of devices. # How do we handle that more gracefully / let the user know? for buffer_resource in self._buffering_resources: - flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + flat_ret = ged_ops.experimental_function_buffering_resource_get_next( buffer_resource, - output_types=data_nest.flatten(sparse.as_dense_types( - self.output_types, self.output_classes)), name=name) + output_types=data_nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + name=name) ret = sparse.deserialize_sparse_tensors( data_nest.pack_sequence_as(self.output_types, flat_ret), @@ -152,13 +152,16 @@ class _PrefetchToDeviceIterator(object): @property def output_types(self): return self._input_dataset.output_types + + # pylint: enable=protected-access -class _PrefetchToDeviceDataset(dataset_ops.Dataset): +class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): """A `Dataset` whose iterator prefetches elements to other device(s).""" def __init__(self, input_dataset, devices, buffer_size): + super(_PrefetchToDeviceDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._devices = devices self._buffer_size = buffer_size if buffer_size is not None else 1 @@ -222,6 +225,7 @@ def prefetch_to_devices(devices, buffer_size=None): A `Dataset` transformation function, which can be passed to `tf.data.Dataset.apply`. """ + def _apply_fn(dataset): return _PrefetchToDeviceDataset(dataset, devices, buffer_size) diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py index bb10b546a1907bba26cd0d7e7c5308420adbaf3f..16799104e8112f4391152c0cf2a15af81f8c2c9d 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -55,14 +55,14 @@ class PrefetchingOpsV2Test(test.TestCase): next_element = iterator.get_next() output = [] + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. with self.cached_session() as sess: - for _ in range(5): + for _ in range(4): result = sess.run(next_element) self.assertEqual(2, len(result)) output.extend(result) - self.assertEquals(set(range(10)), set(output)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + self.assertEquals(set(range(8)), set(output)) def testPrefetchToTwoDevicesWithReinit(self): if not test_util.is_gpu_available(): @@ -75,14 +75,14 @@ class PrefetchingOpsV2Test(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. with self.cached_session() as sess: sess.run(iterator.initializer) - for _ in range(5): - sess.run(next_element) - with self.assertRaises(errors.OutOfRangeError): + for _ in range(4): sess.run(next_element) sess.run(iterator.initializer) - for _ in range(5): + for _ in range(4): sess.run(next_element) diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 5aa19cf6a9f8411120ed929cecaf93dda6c9edf2..09b351ffa4165656e2fc9666ab4b7725ef061f50 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn, def dataset_fn(): dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() - # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be + # TODO(isaprykin): batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. - return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) + return dataset.batch(1, drop_remainder=True) # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 5d498fb629d4a381f56aa7b2db95b09da9010a78..fd280f5754b34170cdd6b948236138d0e77dd8bc 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list @@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase): with ops.control_dependencies([fetched]): g = d.reduce( variable_scope.VariableAggregation.SUM, g, destinations=v) - with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + with ops.control_dependencies(d.update( + v, update, g, grouped=False)): after_list.append(d.read_var(v)) return before_list, after_list diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 6ba83976fcd47fe1680992fbbd5bb56ffa68071d..1d9e299b38409b874610765e54fa0052fafd5f4b 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -29,6 +29,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -37,9 +38,13 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest +_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" + + def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -56,6 +61,58 @@ def get_tpu_system_metadata(tpu_cluster_resolver): return tpu_system_metadata +# TODO(jhseu): Deduplicate with MirroredStrategy? +def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, + **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the TPUMirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # TODO(jhseu): Should we have different behavior for different + # synchronization settings? + + # Get aggregation value + # TODO(jhseu): Support aggregation in a tower context. + aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) + if aggregation not in [ + vs.VariableAggregation.NONE, + vs.VariableAggregation.SUM, + vs.VariableAggregation.MEAN, + vs.VariableAggregation.ONLY_FIRST_TOWER, + ]: + raise ValueError("Invalid variable aggregation mode: {} for variable: {}" + .format(aggregation, kwargs["name"])) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + result = values.TPUMirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + +# TODO(jhseu): Stop inheriting from OneDeviceStrategy. class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" @@ -75,17 +132,28 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): """ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the # master node fetched from the cluster resolver. - super(TPUStrategy, self).__init__('/device:CPU:0') + super(TPUStrategy, self).__init__("/device:CPU:0") self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override self._num_cores_override = num_cores + # TODO(jhseu): Switch to DeviceAssignment to support pods and model + # parallelism. + device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) + if "device:TPU:" in d.name} + self._device_index = values.PerDevice(device_map) + self._tpu_devices = sorted(device_map.keys()) + # Only create variables for the number of towers we're running. + self._tpu_devices = self._tpu_devices[:self.num_towers] + # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run + self._require_static_shapes = True + def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. @@ -158,7 +226,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): raise ValueError( 'TPU currently requires fully defined shapes. Either use ' 'set_shape() on the input tensors or use ' - 'dataset.apply(map_and_batch(..., drop_remainder=True)).') + 'dataset.batch(..., drop_remainder=True).') types = nest.flatten(iterator.output_types) enqueue_ops = [ @@ -231,6 +299,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # For outputs that have already been aggregated, take the first value # from the list as each value should be the same. Else return the full # list of values. + # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. if aggregation is not variables_lib.VariableAggregation.NONE: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] @@ -239,6 +308,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return ctx def _call_for_each_tower(self, fn, *args, **kwargs): + # TODO(jhseu): Consider making it so call_for_each_tower implies that we're + # in a tpu.rewrite(), and update TPUMirroredVariable accordingly. kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access return fn(*args, **kwargs) @@ -248,7 +319,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError('Eager mode not supported in TPUStrategy.') else: - return [tpu.initialize_system()] + # TODO(jhseu): We need this hack because DistributionStrategies must be + # pickleable for copy.deepcopy(). Remove when initialize_system goes away. + graph = ops.get_default_graph() + tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) + if tpu_init: + return tpu_init + graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, + tpu.initialize_system()) + return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) def finalize(self): if context.executing_eagerly(): @@ -257,21 +336,53 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): else: return [tpu.shutdown_system()] + def _get_devices_from(self, colocate_with=None): + # TODO(jhseu): Change this when we support model parallelism. + return self._tpu_devices + + def _create_variable(self, next_creator, *args, **kwargs): + """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring + index = {} + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + # Initialize replicas with the same value: + if context.executing_eagerly(): + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) + else: + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + assert not isinstance(v, values.TPUMirroredVariable) + index[d] = v + return index + + return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) + def _reduce(self, aggregation, value, destinations): - graph = ops.get_default_graph() - cf_context = graph._get_control_flow_context() # pylint: disable=protected-access - # If we're inside the ReplicateContext, reduction should be done using - # CrossReplicaSum while outside we can directly use an add_n op. - while cf_context: - if isinstance(cf_context, tpu.TPUReplicateContext): - if aggregation == vs.VariableAggregation.MEAN: - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self.num_towers) - elif aggregation != vs.VariableAggregation.SUM: - raise NotImplementedError( - 'Currently only support sum & mean in TPUStrategy.') - return tpu_ops.cross_replica_sum(value) - cf_context = cf_context.outer_context + if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access + if aggregation == vs.VariableAggregation.MEAN: + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self.num_towers) + elif aggregation != vs.VariableAggregation.SUM: + raise NotImplementedError( + "Currently only support sum & mean in TPUStrategy.") + return tpu_ops.cross_replica_sum(value) # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is @@ -290,10 +401,46 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return output * (1. / len(value)) return output - def _unwrap(self, value): - if isinstance(value, list): - return value - return [value] + def _update(self, var, options, fn, *args, **kwargs): + assert isinstance(var, values.TPUMirroredVariable) + should_group = options.pop("grouped") + assert not options # Validate that we are processing all of the options. + + if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access + if should_group: + return fn(var, *args, **kwargs) + else: + return [fn(var, *args, **kwargs)] + + # Otherwise, we revert to MirroredStrategy behavior and update each variable + # directly. + updates = {} + for d, v in var._index.items(): # pylint: disable=protected-access + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. + updates[d] = fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.update_regroup(self, updates, should_group) + + # TODO(josh11b): Need to implement _update_non_slot()! + + def read_var(self, var): + assert isinstance(var, values.TPUMirroredVariable) + return var.read_value() + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + return [val.get(device=d) for d in sorted(val.devices)] + elif isinstance(val, list): + # TODO(josh11b): We need to remove this case; per device values should + # be represented using a PerDevice wrapper instead of a list with + # one entry per device. + return val + return [val] + @property def num_towers(self): @@ -307,6 +454,30 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def num_towers_per_host(self): return self._tpu_metadata.num_of_cores_per_host + @property + def between_graph(self): + return False + + @property + def should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + + @property + def worker_devices(self): + return self._tpu_devices + + @property + def parameter_devices(self): + return self._tpu_devices + def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' @@ -324,4 +495,3 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): cluster_spec = self._tpu_cluster_resolver.cluster_spec() if cluster_spec: session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) - diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index fafa6384a1eb84102d6e99a61414767b590ca457..472cb4230c5155369ccf05eef2f82f86f8881bf2 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -22,17 +22,20 @@ from __future__ import division from __future__ import print_function import collections +import contextlib import weakref import six from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context +from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as variables_lib @@ -363,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. strategy = distribution_strategy_context.get_distribution_strategy() - updates = strategy.update(self, f, *args, **kwargs) - grouped = strategy.group(updates) - if isinstance(updates, DistributedValues) and updates.is_tensor_like: - # Make sure we run all updates. Without this, something like - # session.run(mirrored_var.assign*(...)) may only update one tower. - index = {} - for d in updates.devices: - with ops.device(d), ops.control_dependencies([grouped]): - index[d] = array_ops.identity(updates.get(d)) - return Mirrored(index) - else: - return grouped + return strategy.update(self, f, *args, **kwargs) else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -453,6 +445,397 @@ ops.register_tensor_conversion_function(MirroredVariable, _tensor_conversion_mirrored) +def _enclosing_tpu_context(): + # pylint: disable=protected-access + tpu_context = ops.get_default_graph()._get_control_flow_context() + # pylint: enable=protected-access + while tpu_context is not None and not isinstance( + tpu_context, control_flow_ops.XLAControlFlowContext): + tpu_context = tpu_context.outer_context + return tpu_context + + +# TODO(jhseu): Deduplicate code. We copy code because we don't want to +# inherit from DistributedDelegate. DistributedDelegate will not work in a +# tpu.replicate() because it assumes that you're in a device context where you +# can operate on a single version of the variable, but a tpu.replicate() +# operates on all variables and is replicated during a rewrite pass. +class TPUMirroredVariable(checkpointable.CheckpointableBase): + """Holds a map from device to TPU variables whose values are kept in sync.""" + + def __init__(self, index, primary_var, aggregation): + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access + self._index = {device_util.canonicalize(key): value + for key, value in six.iteritems(index)} + self._primary_var = primary_var + self._common_name = self._primary_var.name.split(":")[0] + self._aggregation = aggregation + # Needed for GradientTape + self._trainable = self._primary_var.trainable + # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s + # initializer is composed of the initializers of the components variables. + # However, in some cases, such as when restoring from a checkpoint, we may + # set the _initializer_op property on the entire `TPUMirroredVariable`. + self._initializer_op = None + + def _get(self, device=None): + """Returns the value for the current device or raises a ValueError.""" + if device is None: + tower_context = distribution_strategy_context.get_tower_context() + if tower_context: + device = tower_context.device + else: + device = distribute_lib.get_update_device() + if device is None: + return self._get_cross_tower() + device = device_util.canonicalize(device) + try: + return self._index[device] + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) + + # pylint: disable=multiple-statements + def __add__(self, o): return self.read_value() + o + def __radd__(self, o): return o + self.read_value() + def __sub__(self, o): return self.read_value() - o + def __rsub__(self, o): return o - self.read_value() + def __mul__(self, o): return self.read_value() * o + def __rmul__(self, o): return o * self.read_value() + def __truediv__(self, o): return self.read_value() / o + def __rtruediv__(self, o): return o / self.read_value() + def __floordiv__(self, o): return self.read_value() // o + def __rfloordiv__(self, o): return o // self.read_value() + def __mod__(self, o): return self.read_value() % o + def __rmod__(self, o): return o % self.read_value() + def __lt__(self, o): return self.read_value() < o + def __le__(self, o): return self.read_value() <= o + def __gt__(self, o): return self.read_value() > o + def __ge__(self, o): return self.read_value() >= o + def __and__(self, o): return self.read_value() & o + def __rand__(self, o): return o & self.read_value() + def __or__(self, o): return self.read_value() | o + def __ror__(self, o): return o | self.read_value() + def __xor__(self, o): return self.read_value() ^ o + def __rxor__(self, o): return o ^ self.read_value() + def __getitem__(self, o): return self.read_value()[o] + def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo) + def __rpow__(self, o): return pow(o, self.read_value()) + def __invert__(self): return ~self.read_value() + def __neg__(self): return -self.read_value() + def __abs__(self): return abs(self.read_value()) + + def __div__(self, o): + try: + return self.read_value().__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self.read_value().__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self.read_value().__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self.read_value().__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + @property + def handle(self): + # If we're in a tpu.rewrite(), return the replicated handle. + tpu_context = _enclosing_tpu_context() + if tpu_context is not None: + return tpu_context.get_replicated_var_handle( + self._common_name, nest.flatten(self._index)) + + device = distribute_lib.get_update_device() + if device is None: + return self._primary_var.handle + device = device_util.canonicalize(device) + try: + return self._index[device].handle + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) + + @property + def device(self): + return self._get().device + + # The arguments to update() are automatically unwrapped so the update() + # function would normally see regular variables, not MirroredVariables. + # However, the update function can still operate on wrapped MirroredVariables + # through object members, captured arguments, etc. This is more likely in an + # update_non_slot() function (like OptimizerV2._finish), which can + # update several non-slot variables in one call. + def _assign_func(self, *args, **kwargs): + if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy": + raise ValueError("You may only assign to a TPUMirroredVariable within a " + "TPUStrategy.") + f = kwargs.pop("f") + if distribution_strategy_context.get_cross_tower_context(): + if _enclosing_tpu_context() is not None: + return distribution_strategy_context.get_distribution_strategy().update( + self, f, *args, **kwargs) + + update_device = distribute_lib.get_update_device() + # We are calling update on the mirrored variable in cross tower context. + if update_device is not None: + # We are calling an assign function on the mirrored variable in cross + # tower context. + v = self._get(device=update_device) + return f(v, *args, **kwargs) + + return distribution_strategy_context.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + _assert_tower_context() + # We are calling an assign function on the mirrored variable in tower + # context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function on each of the mirrored variables with the reduced + # value. + if self._aggregation == vs.VariableAggregation.NONE: + raise ValueError("You must specify an aggregation method to update a " + "TPUMirroredVariable in Tower Context.") + + def merge_fn(strategy, value, *other_args, **other_kwargs): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) + + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) + + @contextlib.contextmanager + def _handle_graph(self, handle): + # Note: might have an eager tensor but not be executing eagerly when + # building functions. + if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) + or ops.has_default_graph()): + yield + else: + with handle.graph.as_default(): + yield + + @property + def trainable(self): + return self._trainable + + def _read_variable_op(self, parent_op=None): + if self.trainable: + tape.variable_accessed(self) + if parent_op is not None: + with ops.control_dependencies([parent_op]): + return gen_resource_variable_ops.read_variable_op( + self.handle, self.dtype) + + return gen_resource_variable_ops.read_variable_op( + self.handle, self.dtype) + + def read_value(self): + return self._read_variable_op() + + def assign_sub(self, *args, **kwargs): + def assign_sub_fn(var, delta, **kw): + name = kw.pop("name", None) + read_value = kw.pop("read_value", True) + with self._handle_graph(var.handle): + op = gen_resource_variable_ops.assign_sub_variable_op( + var.handle, ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) + if read_value: + return self._read_variable_op(parent_op=op) + return op + + return self._assign_func(f=assign_sub_fn, *args, **kwargs) + + def assign_add(self, *args, **kwargs): + def assign_add_fn(var, delta, **kw): + name = kw.pop("name", None) + read_value = kw.pop("read_value", True) + with self._handle_graph(var.handle): + op = gen_resource_variable_ops.assign_add_variable_op( + var.handle, ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) + if read_value: + return self._read_variable_op(parent_op=op) + return op + + return self._assign_func(f=assign_add_fn, *args, **kwargs) + + def assign(self, *args, **kwargs): + def assign_fn(var, value, **kw): + name = kw.pop("name", None) + read_value = kw.pop("read_value", True) + with self._handle_graph(var.handle): + op = gen_resource_variable_ops.assign_variable_op( + var.handle, ops.convert_to_tensor(value, dtype=self.dtype), + name=name) + if read_value: + return self._read_variable_op(parent_op=op) + return op + + return self._assign_func(f=assign_fn, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation + + @property + def constraint(self): + return None + + @property + def initializer(self): + if self._initializer_op: + init_op = self._initializer_op + else: + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op + + @property + def graph(self): + return self._primary_var.graph + + @property + def _shared_name(self): + return self._common_name + + @property + def _unique_id(self): + return self._primary_var._unique_id # pylint: disable=protected-access + + @property + def name(self): + return self._primary_var.name + + @property + def dtype(self): + return self._primary_var.dtype + + @property + def shape(self): + return self._primary_var.shape + + def get_shape(self): + return self._primary_var.get_shape() + + def to_proto(self, export_scope=None): + return self._primary_var.to_proto(export_scope=export_scope) + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return self._primary_var + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribution_strategy_context.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self._read_variable_op() + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + # Needed to pass ResourceVariable checks. + @property + def op(self): + return self._primary_var.op + + @property + def _in_graph_mode(self): + return self._primary_var._in_graph_mode # pylint: disable=protected-access + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + # pylint: disable=protected-access + if _enclosing_tpu_context() is None: + return self._get()._dense_var_to_tensor(dtype, name, as_ref) + # pylint: enable=protected-access + if dtype is not None and dtype != self.dtype: + raise NotImplementedError + if as_ref: + return self.handle + else: + return self.read_value() + + def is_initialized(self, name=None): + """Identifies if all the component variables are initialized. + + Args: + name: Name of the final `logical_and` op. + + Returns: + The op that evaluates to True or False depending on if all the + component variables are initialized. + """ + # TODO(jhseu): Do we need TPU context implementation? + + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = nest.flatten(self._index) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For distributed variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False): + return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access + + +ops.register_tensor_conversion_function(TPUMirroredVariable, + _tensor_conversion_tpu_mirrored) +ops.register_dense_tensor_like_type(TPUMirroredVariable) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -668,6 +1051,29 @@ def select_device_mirrored(device, structured): return nest.map_structure(_get_mirrored, structured) +def update_regroup(strategy, updates, should_group): + """Regroup for an update, with dependencies to ensure all updates execute.""" + regrouped = regroup(updates, Mirrored) + if not should_group: + return nest.map_structure(strategy.unwrap, regrouped) + grouped_flat = [] + for u in nest.flatten(regrouped): + if isinstance(u, DistributedValues): + g = strategy.group(u) + if u.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(strategy.update(...)) may only update one tower. + index = {} + for d in u.devices: + with ops.device(d), ops.control_dependencies([g]): + index[d] = array_ops.identity(u.get(d)) + g = Mirrored(index) + else: + g = u + grouped_flat.append(g) + return nest.pack_sequence_as(regrouped, grouped_flat) + + class PerDeviceDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" @@ -726,14 +1132,14 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( - dataset_iterator, self._devices, self._prefetch_on_device) + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( - dataset_iterator, self._devices, self._prefetch_on_device) + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) class MultiWorkerDataIterator(object): @@ -793,7 +1199,8 @@ class MultiWorkerDataset(object): eager mode. """ - def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None): + def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None, + auto_shard=False): """Initialize the MultiWorkerDataset object. Args: @@ -801,6 +1208,7 @@ class MultiWorkerDataset(object): worker_device_map: a dict mapping from each worker to a list of devices that belong to this worker. prefetch_on_device: whether to prefetch to devices. + auto_shard: whether to auto-shard the dataset. """ self._worker_device_map = worker_device_map self._datasets = {} @@ -810,8 +1218,9 @@ class MultiWorkerDataset(object): six.iteritems(worker_device_map)): with ops.device(worker): worker_input = dataset_fn() - worker_input = input_ops.auto_shard_dataset( - worker_input, len(worker_device_map), i) + if auto_shard: + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( worker_input, worker_devices, prefetch_on_device=prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 15a85a28f5bff1dffeda0ed1a47080b49ce50e11..121d2fbb3fbccd913599a581b3de9850ab33eae0 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -375,8 +375,9 @@ class PerDeviceDatasetTest(test.TestCase): combined_expected = [] for expected_value in expected_values: next_element = iterator.get_next() - combined_actual.extend(self.evaluate([ - values.select_device(d, next_element) for d in devices])) + combined_actual.extend( + self.evaluate( + [values.select_device(d, next_element) for d in devices])) combined_expected.extend(expected_value) self.assertEqual(set(combined_expected), set(combined_actual)) @@ -640,7 +641,7 @@ class MirroredVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, devices, mirrored = _make_mirrored() # Overwrite the initial values. @@ -743,7 +744,7 @@ class MirroredVariableTest(test.TestCase): if context.num_gpus() < 1 or context.executing_eagerly(): self.skipTest("A GPU is not available for this test or it's eager mode.") - with self.test_session( + with self.session( graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( ["/device:GPU:0"]).scope(): with ops.device("/device:GPU:0"): @@ -826,7 +827,7 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. @@ -849,7 +850,7 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.test_session() as sess: + with self.cached_session(config=self.config) as sess: v, tower_local = _make_tower_local( variable_scope.VariableAggregation.MEAN) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 9aadc634da5a7591747a4f651cdb45376393402d..60f6b90edcb71f04bca29b90744db201e83cd545 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -25,7 +25,6 @@ py_library( "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:clip_ops", @@ -61,7 +60,6 @@ py_library( ":bijectors_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/learn", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", @@ -301,7 +299,7 @@ cuda_py_test( cuda_py_test( name = "mvn_diag_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_diag_test.py"], additional_deps = [ ":distributions_py", @@ -706,8 +704,8 @@ cuda_py_test( ":bijectors_py", ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", @@ -722,8 +720,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", + "//tensorflow/python/ops/linalg", ], shard_count = 4, tags = ["noasan"], # times out, http://b/78588814 @@ -739,8 +737,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", @@ -794,8 +792,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -831,8 +829,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -852,8 +850,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -871,8 +869,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -907,8 +905,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -926,10 +924,10 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/ops/linalg", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", @@ -945,8 +943,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -964,8 +962,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -983,8 +981,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1002,8 +1000,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1021,8 +1019,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1040,8 +1038,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1075,8 +1073,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1126,8 +1124,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1161,8 +1159,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1180,8 +1178,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1201,8 +1199,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1221,8 +1219,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1240,8 +1238,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1259,8 +1257,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1278,8 +1276,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1297,8 +1295,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1316,8 +1314,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 5cec93c4df2e970f203253be6342bb292f296eb0..343eae3440e30f7d328cd214c5c2cc8208b310e2 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -18,69 +18,73 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member +from tensorflow.python.util import deprecation -from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops.autoregressive import * -from tensorflow.contrib.distributions.python.ops.batch_reshape import * -from tensorflow.contrib.distributions.python.ops.binomial import * -from tensorflow.contrib.distributions.python.ops.cauchy import * -from tensorflow.contrib.distributions.python.ops.chi2 import * -from tensorflow.contrib.distributions.python.ops.conditional_distribution import * -from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * -from tensorflow.contrib.distributions.python.ops.deterministic import * -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform -from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp -from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag -from tensorflow.contrib.distributions.python.ops.estimator import * -from tensorflow.contrib.distributions.python.ops.geometric import * -from tensorflow.contrib.distributions.python.ops.half_normal import * -from tensorflow.contrib.distributions.python.ops.independent import * -from tensorflow.contrib.distributions.python.ops.inverse_gamma import * -from tensorflow.contrib.distributions.python.ops.kumaraswamy import * -from tensorflow.contrib.distributions.python.ops.logistic import * -from tensorflow.contrib.distributions.python.ops.mixture import * -from tensorflow.contrib.distributions.python.ops.mixture_same_family import * -from tensorflow.contrib.distributions.python.ops.moving_stats import * -from tensorflow.contrib.distributions.python.ops.mvn_diag import * -from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * -from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * -from tensorflow.contrib.distributions.python.ops.mvn_tril import * -from tensorflow.contrib.distributions.python.ops.negative_binomial import * -from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * -from tensorflow.contrib.distributions.python.ops.onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.poisson import * -from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * -from tensorflow.contrib.distributions.python.ops.quantized_distribution import * -from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * -from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.sample_stats import * -from tensorflow.contrib.distributions.python.ops.seed_stream import * -from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * -from tensorflow.contrib.distributions.python.ops.test_util import * -from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * -from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * -from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * -from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * -from tensorflow.contrib.distributions.python.ops.wishart import * -from tensorflow.python.ops.distributions.bernoulli import * -from tensorflow.python.ops.distributions.beta import * -from tensorflow.python.ops.distributions.categorical import * -from tensorflow.python.ops.distributions.dirichlet import * -from tensorflow.python.ops.distributions.dirichlet_multinomial import * -from tensorflow.python.ops.distributions.distribution import * -from tensorflow.python.ops.distributions.exponential import * -from tensorflow.python.ops.distributions.gamma import * -from tensorflow.python.ops.distributions.kullback_leibler import * -from tensorflow.python.ops.distributions.laplace import * -from tensorflow.python.ops.distributions.multinomial import * -from tensorflow.python.ops.distributions.normal import * -from tensorflow.python.ops.distributions.student_t import * -from tensorflow.python.ops.distributions.transformed_distribution import * -from tensorflow.python.ops.distributions.uniform import * + +# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top + +with deprecation.silence(): + from tensorflow.contrib.distributions.python.ops import bijectors + from tensorflow.contrib.distributions.python.ops.autoregressive import * + from tensorflow.contrib.distributions.python.ops.batch_reshape import * + from tensorflow.contrib.distributions.python.ops.binomial import * + from tensorflow.contrib.distributions.python.ops.cauchy import * + from tensorflow.contrib.distributions.python.ops.chi2 import * + from tensorflow.contrib.distributions.python.ops.conditional_distribution import * + from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * + from tensorflow.contrib.distributions.python.ops.deterministic import * + from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular + from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform + from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp + from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag + from tensorflow.contrib.distributions.python.ops.estimator import * + from tensorflow.contrib.distributions.python.ops.geometric import * + from tensorflow.contrib.distributions.python.ops.half_normal import * + from tensorflow.contrib.distributions.python.ops.independent import * + from tensorflow.contrib.distributions.python.ops.inverse_gamma import * + from tensorflow.contrib.distributions.python.ops.kumaraswamy import * + from tensorflow.contrib.distributions.python.ops.logistic import * + from tensorflow.contrib.distributions.python.ops.mixture import * + from tensorflow.contrib.distributions.python.ops.mixture_same_family import * + from tensorflow.contrib.distributions.python.ops.moving_stats import * + from tensorflow.contrib.distributions.python.ops.mvn_diag import * + from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * + from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * + from tensorflow.contrib.distributions.python.ops.mvn_tril import * + from tensorflow.contrib.distributions.python.ops.negative_binomial import * + from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * + from tensorflow.contrib.distributions.python.ops.onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.poisson import * + from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * + from tensorflow.contrib.distributions.python.ops.quantized_distribution import * + from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * + from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.sample_stats import * + from tensorflow.contrib.distributions.python.ops.seed_stream import * + from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * + from tensorflow.contrib.distributions.python.ops.test_util import * + from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * + from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * + from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * + from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * + from tensorflow.contrib.distributions.python.ops.wishart import * + from tensorflow.python.ops.distributions.bernoulli import * + from tensorflow.python.ops.distributions.beta import * + from tensorflow.python.ops.distributions.categorical import * + from tensorflow.python.ops.distributions.dirichlet import * + from tensorflow.python.ops.distributions.dirichlet_multinomial import * + from tensorflow.python.ops.distributions.distribution import * + from tensorflow.python.ops.distributions.exponential import * + from tensorflow.python.ops.distributions.gamma import * + from tensorflow.python.ops.distributions.kullback_leibler import * + from tensorflow.python.ops.distributions.laplace import * + from tensorflow.python.ops.distributions.multinomial import * + from tensorflow.python.ops.distributions.normal import * + from tensorflow.python.ops.distributions.student_t import * + from tensorflow.python.ops.distributions.transformed_distribution import * + from tensorflow.python.ops.distributions.uniform import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index 8dad80aa647f0c7d53685aed4025dd49ffa0f6d0..c32ea9ade73c3cfb285bb32ebb91908910c34c5c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -93,12 +93,12 @@ class SoftsignBijectorTest(test.TestCase): bijector.inverse_log_det_jacobian(y, event_ndims=1))) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = Softsign(validate_args=True) assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = Softsign(validate_args=True) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.linspace(-0.99, 0.99, 100).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index f073f51a6983c9ac016630bf1dba405c73db6354..9b9b3ce2dd9d42286d2d9657d5f00de8445261f0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -212,7 +212,7 @@ class DistributionTest(test.TestCase): def testStrWorksCorrectlyScalar(self): normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) self.assertEqual( - ("tf.distributions.Normal(" + ("tfp.distributions.Normal(" "\"Normal/\", " "batch_shape=(), " "event_shape=(), " @@ -221,7 +221,7 @@ class DistributionTest(test.TestCase): chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") self.assertEqual( - ("tf.distributions.Chi2(" + ("tfp.distributions.Chi2(" "\"silly/\", " # What a silly name that is! "batch_shape=(2,), " "event_shape=(), " @@ -230,7 +230,7 @@ class DistributionTest(test.TestCase): exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) self.assertEqual( - ("tf.distributions.Exponential(\"Exponential/\", " + ("tfp.distributions.Exponential(\"Exponential/\", " # No batch shape. "event_shape=(), " "dtype=float32)"), @@ -240,7 +240,7 @@ class DistributionTest(test.TestCase): mvn_static = tfd.MultivariateNormalDiag( loc=np.zeros([2, 2]), name="MVN") self.assertEqual( - ("tf.distributions.MultivariateNormalDiag(" + ("tfp.distributions.MultivariateNormalDiag(" "\"MVN/\", " "batch_shape=(2,), " "event_shape=(2,), " @@ -251,7 +251,7 @@ class DistributionTest(test.TestCase): loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), name="MVN2") self.assertEqual( - ("tf.distributions.MultivariateNormalDiag(" + ("tfp.distributions.MultivariateNormalDiag(" "\"MVN2/\", " "batch_shape=(?,), " # Partially known. "event_shape=(3,), " @@ -261,7 +261,7 @@ class DistributionTest(test.TestCase): def testReprWorksCorrectlyScalar(self): normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) self.assertEqual( - ("" " event_shape=()" @@ -290,7 +290,7 @@ class DistributionTest(test.TestCase): mvn_static = tfd.MultivariateNormalDiag( loc=np.zeros([2, 2]), name="MVN") self.assertEqual( - ("= 0, dtype=x.dtype) - - -def _det_large_enough_mask(x, det_bounds): - """Returns whether the input matches the given determinant limit. - - Args: - x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. - det_bounds: A floating-point `Tensor` that must broadcast to shape - `[B1, ..., Bn]`, giving the desired lower bound on the - determinants in `x`. - - Returns: - mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each - scalar is 1 if the corresponding matrix had determinant above - the corresponding bound, otherwise 0. - """ - # For the curious: I wonder whether it is possible and desirable to - # use a Cholesky decomposition-based algorithm for this, since the - # only matrices whose determinant this code cares about will be PSD. - # Didn't figure out how to code that in TensorFlow. - # - # Expert opinion is that it would be about twice as fast since - # Cholesky is roughly half the cost of Gaussian Elimination with - # Partial Pivoting. But this is less of an impact than the switch in - # _psd_mask. - return math_ops.cast( - linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) - - -def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): - """Returns a uniformly random `Tensor` of "correlation-like" matrices. - - A "correlation-like" matrix is a symmetric square matrix with all entries - between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, - the ones that are positive semi-definite are exactly the correlation - matrices. - - Args: - num_rows: Python `int` dimension of the correlation-like matrices. - batch_shape: `Tensor` or Python `tuple` of `int` shape of the - batch to return. - dtype: `dtype` of the `Tensor` to return. - seed: Random seed. - - Returns: - matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` - and dtype `dtype`. Each entry is in [-1, 1], and each matrix - along the bottom two dimensions is symmetric and has 1s on the - main diagonal. - """ - num_entries = num_rows * (num_rows + 1) / 2 - ones = array_ops.ones(shape=[num_entries], dtype=dtype) - # It seems wasteful to generate random values for the diagonal since - # I am going to throw them away, but `fill_triangular` fills the - # diagonal, so I probably need them. - # It's not impossible that it would be more efficient to just fill - # the whole matrix with random values instead of messing with - # `fill_triangular`. Then would need to filter almost half out with - # `matrix_band_part`. - unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) - tril = util.fill_triangular(unifs) - symmetric = tril + array_ops.matrix_transpose(tril) - diagonal_ones = array_ops.ones( - shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), - dtype=dtype) - return array_ops.matrix_set_diag(symmetric, diagonal_ones) - - -def correlation_matrix_volume_rejection_samples( - det_bounds, dim, sample_shape, dtype, seed): - """Returns rejection samples from trying to get good correlation matrices. - - The proposal being rejected from is the uniform distribution on - "correlation-like" matrices. We say a matrix is "correlation-like" - if it is a symmetric square matrix with all entries between -1 and 1 - (inclusive) and 1s on the main diagonal. Of these, the ones that - are positive semi-definite are exactly the correlation matrices. - - The rejection algorithm, then, is to sample a `Tensor` of - `sample_shape` correlation-like matrices of dimensions `dim` by - `dim`, and check each one for (i) being a correlation matrix (i.e., - PSD), and (ii) having determinant at least the corresponding entry - of `det_bounds`. - - Args: - det_bounds: A `Tensor` of lower bounds on the determinants of - acceptable matrices. The shape must broadcast with `sample_shape`. - dim: A Python `int` dimension of correlation matrices to sample. - sample_shape: Python `tuple` of `int` shape of the samples to - compute, excluding the two matrix dimensions. - dtype: The `dtype` in which to do the computation. - seed: Random seed. - - Returns: - weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the - corresponding matrix was not a correlation matrix, or had too - small of a determinant. Otherwise, the entry is the - multiplicative inverse of the density of proposing that matrix - uniformly, i.e., the volume of the set of `dim` by `dim` - correlation-like matrices. - volume: The volume of the set of `dim` by `dim` correlation-like - matrices. - """ - with ops.name_scope("rejection_sampler"): - rej_proposals = _uniform_correlation_like_matrix( - dim, sample_shape, dtype, seed=seed) - rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) - # The density of proposing any given point is 1 / rej_proposal_volume; - # The weight of that point should be scaled by - # 1 / density = rej_proposal_volume. - rej_weights = rej_proposal_volume * _psd_mask( - rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) - return rej_weights, rej_proposal_volume - - -def _clopper_pearson_confidence_interval(samples, error_rate): - """Computes a confidence interval for the mean of the given 1-D distribution. - - Assumes (and checks) that the given distribution is Bernoulli, i.e., - takes only two values. This licenses using the CDF of the binomial - distribution for the confidence, which is tighter (for extreme - probabilities) than the DKWM inequality. The method is known as the - [Clopper-Pearson method] - (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). - - Assumes: - - - The given samples were drawn iid from the distribution of interest. - - - The given distribution is a Bernoulli, i.e., supported only on - low and high. - - Guarantees: - - - The probability (over the randomness of drawing the given sample) - that the true mean is outside the returned interval is no more - than the given error_rate. - - Args: - samples: `np.ndarray` of samples drawn iid from the distribution - of interest. - error_rate: Python `float` admissible rate of mistakes. - - Returns: - low: Lower bound of confidence interval. - high: Upper bound of confidence interval. - - Raises: - ValueError: If `samples` has rank other than 1 (batch semantics - are not implemented), or if `samples` contains values other than - `low` or `high` (as that makes the distribution not Bernoulli). - """ - # TODO(b/78025336) Migrate this confidence interval function - # to statistical_testing.py. In order to do that - # - Get the binomial CDF from the Binomial distribution - # - Implement scalar root finding in TF. Batch bisection search - # shouldn't be too hard, and is definitely good enough for this - # problem. Batching the Brent algorithm (from scipy) that is used - # here may be more involved, but may also not be necessary---it's - # only used here because scipy made it convenient. In particular, - # robustness is more important than speed here, which may make - # bisection search actively better. - # - The rest is just a matter of rewriting in the appropriate style. - if optimize is None or stats is None: - raise ValueError( - "Scipy is required for computing Clopper-Pearson confidence intervals") - if len(samples.shape) != 1: - raise ValueError("Batch semantics not implemented") - n = len(samples) - low = np.amin(samples) - high = np.amax(samples) - successes = np.count_nonzero(samples - low) - failures = np.count_nonzero(samples - high) - if successes + failures != n: - uniques = np.unique(samples) - msg = ("Purportedly Bernoulli distribution had distinct samples" - " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) - raise ValueError(msg) - def p_small_enough(p): - prob = stats.binom.logcdf(successes, n, p) - return prob - np.log(error_rate / 2.) - def p_big_enough(p): - prob = stats.binom.logsf(successes, n, p) - return prob - np.log(error_rate / 2.) - high_p = optimize.brentq( - p_small_enough, float(successes) / n, 1., rtol=1e-9) - low_p = optimize.brentq( - p_big_enough, 0., float(successes) / n, rtol=1e-9) - low_interval = low + (high - low) * low_p - high_interval = low + (high - low) * high_p - return (low_interval, high_interval) - - -def compute_true_volumes( - det_bounds, dim, num_samples, error_rate=1e-6, seed=42): - """Returns confidence intervals for the desired correlation matrix volumes. - - The confidence intervals are computed by the [Clopper-Pearson method] - (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). - - Args: - det_bounds: A rank-1 numpy array of lower bounds on the - determinants of acceptable matrices. Entries must be unique. - dim: A Python `int` dimension of correlation matrices to sample. - num_samples: The number of samples to draw. - error_rate: The statistical significance of the returned - confidence intervals. The significance is broadcast: Each - returned interval separately may be incorrect with probability - (under the sample of correlation-like matrices drawn internally) - at most `error_rate`. - seed: Random seed. - - Returns: - bounds: A Python `dict` mapping each determinant bound to the low, high - tuple giving the confidence interval. - """ - bounds = {} - with session.Session() as sess: - rej_weights, _ = correlation_matrix_volume_rejection_samples( - det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) - rej_weights = sess.run(rej_weights) - for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): - template = ("Estimating volume of {}x{} correlation " - "matrices with determinant >= {}.") - print(template.format(dim, dim, det)) - sys.stdout.flush() - bounds[det] = _clopper_pearson_confidence_interval( - rw, error_rate=error_rate) - return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py deleted file mode 100644 index 8f99300e63871119800a42f122c8321e5986541a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for correlation_matrix_volumes_lib.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr -from tensorflow.contrib.distributions.python.ops import statistical_testing as st -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.platform import test - - -# NxN correlation matrices are determined by the N*(N-1)/2 -# lower-triangular entries. In addition to being between -1 and 1, -# they must also obey the constraint that the determinant of the -# resulting symmetric matrix is non-negative. In 2x2, we can even -# analytically compute the volume when the determinant is bounded to > -# epsilon, as that boils down to the one lower-triangular entry being -# less than 1 - epsilon in absolute value. -def two_by_two_volume(det_bound): - return 2 * np.sqrt(1.0 - det_bound) - - -# The post -# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ -# derives (with elementary calculus) that the volume (with respect to -# Lebesgue^3 measure) of the set of 3x3 correlation matrices is -# pi^2/2. The same result is also obtained by [1]. -def three_by_three_volume(): - return np.pi**2 / 2. - - -# The volume of the unconstrained set of correlation matrices is also -# the normalization constant of the LKJ distribution from [2]. As -# part of defining the distribution, that reference a derives general -# formula for this volume for all dimensions. A TensorFlow -# computation thereof gave the below result for 4x4: -def four_by_four_volume(): - # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) - return 11.6973076 - -# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of -# correlation matrices." The American Statistician, 48(4), 276-279. - -# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating -# random correlation matrices based on vines and extended onion -# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. - - -class CorrelationMatrixVolumesTest(test.TestCase): - - def testRejection2D(self): - num_samples = int(1e5) # Chosen for a small min detectable discrepancy - det_bounds = np.array( - [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) - exact_volumes = two_by_two_volume(det_bounds) - (rej_weights, - rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( - det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) - # shape of rej_weights: [num_samples, 9, 2, 2] - chk1 = st.assert_true_mean_equal_by_dkwm( - rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, - false_fail_rate=1e-6) - chk2 = check_ops.assert_less( - st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, low=0., high=rej_proposal_volume, - # Correct the false fail rate due to different broadcasting - false_fail_rate=1.1e-7, false_pass_rate=1e-6), - 0.036) - with ops.control_dependencies([chk1, chk2]): - rej_weights = array_ops.identity(rej_weights) - self.evaluate(rej_weights) - - def testRejection3D(self): - num_samples = int(1e5) # Chosen for a small min detectable discrepancy - det_bounds = np.array([0.0], dtype=np.float32) - exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) - (rej_weights, - rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( - det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) - # shape of rej_weights: [num_samples, 1, 3, 3] - chk1 = st.assert_true_mean_equal_by_dkwm( - rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, - false_fail_rate=1e-6) - chk2 = check_ops.assert_less( - st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, low=0., high=rej_proposal_volume, - false_fail_rate=1e-6, false_pass_rate=1e-6), - # Going for about a 3% relative error - 0.15) - with ops.control_dependencies([chk1, chk2]): - rej_weights = array_ops.identity(rej_weights) - self.evaluate(rej_weights) - - def testRejection4D(self): - num_samples = int(1e5) # Chosen for a small min detectable discrepancy - det_bounds = np.array([0.0], dtype=np.float32) - exact_volumes = [four_by_four_volume()] - (rej_weights, - rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( - det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) - # shape of rej_weights: [num_samples, 1, 4, 4] - chk1 = st.assert_true_mean_equal_by_dkwm( - rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, - false_fail_rate=1e-6) - chk2 = check_ops.assert_less( - st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, low=0., high=rej_proposal_volume, - false_fail_rate=1e-6, false_pass_rate=1e-6), - # Going for about a 10% relative error - 1.1) - with ops.control_dependencies([chk1, chk2]): - rej_weights = array_ops.identity(rej_weights) - self.evaluate(rej_weights) - - def testVolumeEstimation2D(self): - # Test that the confidence intervals produced by - # corr.compte_true_volumes are sound, in the sense of containing - # the exact volume. - num_samples = int(1e5) # Chosen by symmetry with testRejection2D - det_bounds = np.array( - [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) - volume_bounds = corr.compute_true_volumes( - det_bounds, 2, num_samples, error_rate=1e-6, seed=47) - exact_volumes = two_by_two_volume(det_bounds) - for det, volume in zip(det_bounds, exact_volumes): - computed_low, computed_high = volume_bounds[det] - self.assertLess(computed_low, volume) - self.assertGreater(computed_high, volume) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index bb9b8043b2233b2109f51b5dde188d088fdb0d39..3ba1c3a66517887dba204e081d3f31a95d86e295 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -65,13 +65,14 @@ class Autoregressive(distribution_lib.Distribution): ``` where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn` - constructs a `tf.distributions.Distribution`-like instance, and `x0` is a + constructs a `tfp.distributions.Distribution`-like instance, and `x0` is a fixed initializing `Tensor`. #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions def normal_fn(self, event_size): n = event_size * (event_size + 1) / 2 @@ -127,7 +128,7 @@ class Autoregressive(distribution_lib.Distribution): Args: distribution_fn: Python `callable` which constructs a - `tf.distributions.Distribution`-like instance from a `Tensor` (e.g., + `tfp.distributions.Distribution`-like instance from a `Tensor` (e.g., `sample0`). The function must respect the "autoregressive property", i.e., there exists a permutation of event such that each coordinate is a diffeomorphic function of on preceding coordinates. diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index 519077bc9ab1063a1135486cfae34656f3f68157..612376efb7f43b0dfcd3ffeb5437f2a419f66f4d 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -45,7 +45,8 @@ class BatchReshape(distribution_lib.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions dtype = np.float32 dims = 2 diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 296e66f2b24fecf2142066727b5b12ee5cbd0379..3b3d8ee6f2dc595983fd5e283d0435e8a227f2ba 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -61,8 +61,8 @@ class MaskedAutoregressiveFlow(bijector.Bijector): `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves this property by zeroing out weights in its `masked_dense` layers. - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression" + In the `tfp` framework, a "normalizing flow" is implemented as a + `tfp.bijectors.Bijector`. The `forward` "autoregression" is implemented using a `tf.while_loop` and a deep neural network (DNN) with masked weights such that the autoregressive property is automatically met in the `inverse`. @@ -126,8 +126,9 @@ class MaskedAutoregressiveFlow(bijector.Bijector): #### Examples ```python - tfd = tf.contrib.distributions - tfb = tfd.bijectors + import tensorflow_probability as tfp + tfd = tfp.distributions + tfb = tfp.bijectors dims = 5 diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index f182a1adcbb6b11af2376cd271f903d50e50f1a0..178c3c94bfd319e3182a60054ea55c4ccaf01607 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -41,9 +41,10 @@ class Permute(bijector.Bijector): """Permutes the rightmost dimension of a `Tensor`. ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfb = tfp.bijectors - reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + reverse = tfb.Permute(permutation=[2, 1, 0]) reverse.forward([-1., 0., 1.]) # ==> [1., 0., -1] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 773ae2446118051a61636bc21de6b81dfacda746..0bcb08cdea7142b82af3116245306a11773ef93c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -90,8 +90,9 @@ class RealNVP(bijector.Bijector): #### Example Use ```python - tfd = tf.contrib.distributions - tfb = tfd.bijectors + import tensorflow_probability as tfp + tfd = tfp.distributions + tfb = tfp.bijectors # A common choice for a normalizing flow is to use a Gaussian for the base # distribution. (However, any continuous distribution would work.) E.g., diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9..71ac29038fc12e7d046df8624c6e3e5bb97d3d8f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -80,9 +80,10 @@ class Reshape(bijector.Bijector): Example usage: ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfb = tfp.bijectors - r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + r = tfb.Reshape(event_shape_out=[1, -1]) r.forward([3., 4.]) # shape [2] # ==> [[3., 4.]] # shape [1, 2] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py index 6fbe8665781211ca803feb8bf5a8c04fb0b969e8..0a6d690b65cdfa5944c737cecc5caae2adebd7dd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -42,7 +42,10 @@ class ScaleTriL(chain.Chain): #### Examples ```python - tfb = tf.contrib.distributions.bijectors + import tensorflow_probability as tfp + tfd = tfp.distributions + tfb = tfp.bijectors + b = tfb.ScaleTriL( diag_bijector=tfb.Exp(), diag_shift=None) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index cb5223b0557080e10bf24c3e1cb432f15fd5e7e3..c461833b9ae91d8c3525b4099580a8f0caceadae 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -63,7 +63,8 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Define a single scalar Cauchy distribution. dist = tfd.Cauchy(loc=0., scale=3.) diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index affc64a14f6fe9ae6e08ceff2298bc99ee7caa43..507c5d36794df75c09d2293ed66111c17c06af37 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -198,8 +198,11 @@ class Deterministic(_BaseDeterministic): #### Examples ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + # Initialize a single Deterministic supported at zero. - constant = tf.contrib.distributions.Deterministic(0.) + constant = tfd.Deterministic(0.) constant.prob(0.) ==> 1. constant.prob(2.) @@ -208,7 +211,7 @@ class Deterministic(_BaseDeterministic): # Initialize a [2, 2] batch of scalar constants. loc = [[0., 1.], [2., 3.]] x = [[0., 1.1], [1.99, 3.]] - constant = tf.contrib.distributions.Deterministic(loc) + constant = tfd.Deterministic(loc) constant.prob(x) ==> [[1., 0.], [0., 1.]] ``` @@ -310,7 +313,8 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. constant = tfd.Deterministic([0., 2.]) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b..4b50df5b481513aa964e680dbb60cc5c641410aa 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -63,7 +63,8 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Define a single scalar Gumbel distribution. dist = tfd.Gumbel(loc=0., scale=3.) diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index b02c4031069191592b8acc1a90313450f98af6d7..f1216370869f1e7e3168acc959f14eb4bd874984 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -66,15 +66,18 @@ class HalfNormal(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + # Define a single scalar HalfNormal distribution. - dist = tf.contrib.distributions.HalfNormal(scale=3.0) + dist = tfd.HalfNormal(scale=3.0) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued HalfNormals. # The first has scale 11.0, the second 22.0 - dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) + dist = tfd.HalfNormal(scale=[11.0, 22.0]) # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, # returning a length two tensor. diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 0672702b96c1eb81c176774554df3f5922a0319e..e1cfff3c66a2bcbc98af8a257dbdea2d916270e2 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -70,7 +70,8 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Make independent distribution from a 2-batch Normal. ind = tfd.Independent( diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 70d050d7a647b38928ddb1c788db0e6957ac0f03..452628257ea96713453bf2aa32b5baa9d6d0cb86 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -89,7 +89,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 02e3bad51ee48188acf83cb09359861c9e6932c7..21c9b5a35448e4195a278e0e31ca9657df49f618 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -61,7 +61,8 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Define a single scalar Logistic distribution. dist = tfd.Logistic(loc=0., scale=3.) diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 3b7114ef067c0aaede23fff04c40d1dc6e830f1c..52b67f2c54c89eaed6c500d32f79865453030644 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -50,7 +50,9 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions + mix = 0.3 bimix_gauss = tfd.Mixture( cat=tfd.Categorical(probs=[mix, 1.-mix]), diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 8ffee940d03c9a5204f2ac6f7acd9ea482adae1a..f4d394ff29f072907a019afb52bd8dc5d244e955 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -44,7 +44,8 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions ### Create a mixture of two scalar Gaussians: @@ -113,12 +114,12 @@ class MixtureSameFamily(distribution.Distribution): """Construct a `MixtureSameFamily` distribution. Args: - mixture_distribution: `tf.distributions.Categorical`-like instance. + mixture_distribution: `tfp.distributions.Categorical`-like instance. Manages the probability of selecting components. The number of categories must match the rightmost batch dimension of the `components_distribution`. Must have either scalar `batch_shape` or `batch_shape` matching `components_distribution.batch_shape[:-1]`. - components_distribution: `tf.distributions.Distribution`-like instance. + components_distribution: `tfp.distributions.Distribution`-like instance. Right-most batch dimension indexes components. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index cd0c282ba6cebf784261a4e821f36ce4eed98fe0..0b5b76be9231e09d3d6937ee889e73a1db4f6f03 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -85,7 +85,8 @@ class MultivariateNormalDiag( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 2-variate Gaussian. mvn = tfd.MultivariateNormalDiag( diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 74d9d04fc702a90a5fc5a31f554abe257dd2860d..80546083d3f908f97ddc3fb9d9d130f3609c3d56 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -87,7 +87,8 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097..bcb4937980020f622659103d0315275439206255 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -73,7 +73,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f..8fdc99824b6dca452b38b706e03f964d05bfaffc 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -91,7 +91,8 @@ class MultivariateNormalLinearOperator( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index c6a23e4336fffbf7b61490dd3468bc71c7f421cc..c21f70fc3b36fdd2ff1d293712952ab932138edd 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -77,13 +77,14 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `tf.contrib.distributions.matrix_diag_transform()` and/or - `tf.contrib.distributions.fill_triangular()` + `tfp.distributions.matrix_diag_transform()` and/or + `tfp.distributions.fill_triangular()` #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 7a7ad1be35b80ff0f000181ea0778ab282a8220f..85683e3233d659e5b3470b96b610342dbeee2e17 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -220,7 +220,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 18a0f754e6e618f240db109f593a80dec57e200b..134658deabe8d69b5747cd32879f92fbbaab1b5a 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -196,8 +196,9 @@ class QuantizedDistribution(distributions.Distribution): parameter determining the unnormalized probability of that component. ```python - tfd = tf.contrib.distributions - tfb = tfd.bijectors + import tensorflow_probability as tfp + tfd = tfp.distributions + tfb = tfp.bijectors net = wavenet(inputs) loc, unconstrained_scale, logits = tf.split(net, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7..4b520b912e74313dce00ce71c7da093728d36075 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -124,7 +124,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `tf.distributions.Normal(0., 1.)`. + Default is `tfp.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index c25e8c51d7705b641699fb05623c7b0fb4950e1b..af22f4843a00938b1a6742c86f2346055c15b817 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -30,27 +30,27 @@ is some expected constant. Suppose the support of P is the interval `[0, 1]`. Then you might do this: ```python -tfd = tf.contrib.distributions - -expected_mean = ... -num_samples = 5000 -samples = ... draw 5000 samples from P - -# Check that the mean looks right -check1 = tfd.assert_true_mean_equal_by_dkwm( - samples, low=0., high=1., expected=expected_mean, - false_fail_rate=1e-6) - -# Check that the difference in means detectable with 5000 samples is -# small enough -check2 = tf.assert_less( - tfd.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, low=0., high=1.0, - false_fail_rate=1e-6, false_pass_rate=1e-6), - 0.01) - -# Be sure to execute both assertion ops -sess.run([check1, check2]) + from tensorflow_probability.python.distributions.internal import statistical_testing + + expected_mean = ... + num_samples = 5000 + samples = ... draw 5000 samples from P + + # Check that the mean looks right + check1 = statistical_testing.assert_true_mean_equal_by_dkwm( + samples, low=0., high=1., expected=expected_mean, + false_fail_rate=1e-6) + + # Check that the difference in means detectable with 5000 samples is + # small enough + check2 = tf.assert_less( + statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=1.0, + false_fail_rate=1e-6, false_pass_rate=1e-6), + 0.01) + + # Be sure to execute both assertion ops + sess.run([check1, check2]) ``` The second assertion is an instance of experiment design. It's a diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index ece03fe4aab3cc3046e0958d883ca9388517b94b..a3d178357b79b9d0d15c738603d5019321eff112 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered -from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -36,6 +35,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.linalg import linear_operator_addition as linop_add_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib @@ -300,7 +300,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.], # another with mix_loc=[1]. In both cases, `K=2` and the affine diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a..36cbd71f8b33a3a00ace2ed5ebd8447c940638d6 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -90,7 +90,8 @@ class VectorExponentialDiag( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 9a47b4855763a25b484ad04a3415d191f19256f7..fd5bf9ecc722ea5247f63a76d8f93f8072d6a029 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -108,7 +108,8 @@ class VectorExponentialLinearOperator( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index e68ddc569c95ff63760b4b2f6d7a92f17240a558..8cd4e128c7a835e6cc991e3456ff65e4603a12e6 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -102,7 +102,8 @@ class VectorLaplaceDiag( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 2-variate VectorLaplace. vla = tfd.VectorLaplaceDiag( diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 3923161a332a77e4eaab8d65d96fd8c278c872ec..67d2ccd28d6d487acccf5e894618457194dce913 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -110,7 +110,8 @@ class VectorLaplaceLinearOperator( #### Examples ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 49ffff24caec8d6c525f65f06796d10548d5ec40..da57d0cb55d72d00d213c0d131a13c702a22cd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -152,7 +152,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `tf.distributions.Normal(loc=0., scale=1.)`. + `tfp.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index f289b39e51aff36780541a0545ed9e6cfe21dd4e..bad91a08447f5ab443a330ddb1411c80938cf823 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -92,7 +92,8 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - tfd = tf.contrib.distributions + import tensorflow_probability as tfp + tfd = tfp.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 49b9de0ab508f5db090bb1349f596da1b2a71b49..ee2fc58864d4ac528ebae3d681d2e4922fb60771 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -480,11 +480,14 @@ class WishartCholesky(_WishartLinearOperator): #### Examples ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + # Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5 # degrees-of-freedom.(*) df = 5 chol_scale = tf.cholesky(...) # Shape is [3, 3]. - dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale) + dist = tfd.WishartCholesky(df=df, scale=chol_scale) # Evaluate this on an observation in R^3, returning a scalar. x = ... # A 3x3 positive definite matrix. @@ -498,14 +501,14 @@ class WishartCholesky(_WishartLinearOperator): # Initialize two 3x3 Wisharts with Cholesky factored scale matrices. df = [5, 4] chol_scale = tf.cholesky(...) # Shape is [2, 3, 3]. - dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale) + dist = tfd.WishartCholesky(df=df, scale=chol_scale) # Evaluate this on four observations. x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3]. dist.prob(x) # Shape is [2, 2]. # (*) - To efficiently create a trainable covariance matrix, see the example - # in tf.contrib.distributions.matrix_diag_transform. + # in tfp.distributions.matrix_diag_transform. ``` """ @@ -604,11 +607,14 @@ class WishartFull(_WishartLinearOperator): #### Examples ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + # Initialize a single 3x3 Wishart with Full factored scale matrix and 5 # degrees-of-freedom.(*) df = 5 scale = ... # Shape is [3, 3]; positive definite. - dist = tf.contrib.distributions.WishartFull(df=df, scale=scale) + dist = tfd.WishartFull(df=df, scale=scale) # Evaluate this on an observation in R^3, returning a scalar. x = ... # A 3x3 positive definite matrix. @@ -622,14 +628,14 @@ class WishartFull(_WishartLinearOperator): # Initialize two 3x3 Wisharts with Full factored scale matrices. df = [5, 4] scale = ... # Shape is [2, 3, 3]. - dist = tf.contrib.distributions.WishartFull(df=df, scale=scale) + dist = tfd.WishartFull(df=df, scale=scale) # Evaluate this on four observations. x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3]; xi is positive definite. dist.prob(x) # Shape is [2, 2]. # (*) - To efficiently create a trainable covariance matrix, see the example - # in tf.contrib.distributions.matrix_diag_transform. + # in tfd.matrix_diag_transform. ``` """ diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 86d203452e24d6d73f3ebb17b989867905a61382..4bd2769e879fb0bfc30a2de73d1fcf65d7b3bf19 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -44,7 +44,6 @@ Installation instructions at https://www.tensorflow.org/install/ For an introduction to eager execution in TensorFlow, see: -- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md)) -- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) -- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) -- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) +- [User Guide](https://www.tensorflow.org/guide/eager) ([source](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/index.md)) +- Notebook: [Basic Usage](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb) +- Notebook: [Automatic differentiation and gradient tape](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 84517b57c7d0af56ba7724d18e78f38041ebe773..33a1d572a20e68479d3ec1147d4892449e7beb8a 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -14,6 +14,7 @@ py_library( ":datasets", ":metrics", ":network", + ":parameter_server", ":remote", ":saver", "//tensorflow/python:framework_ops", @@ -97,6 +98,18 @@ py_library( ], ) +py_library( + name = "parameter_server", + srcs = ["parameter_server.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + ], +) + cuda_py_test( name = "saver_test", srcs = ["saver_test.py"], @@ -241,6 +254,7 @@ py_test( srcs = ["remote_test.py"], srcs_version = "PY2AND3", deps = [ + ":parameter_server", ":remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 135095a97980da8988b976948fb18492526e390c..3aed121233be1268531495a2fa83fd323412e1fd 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops @@ -54,7 +54,7 @@ class Iterator(iterator_ops.EagerIterator): """ if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access raise TypeError( - "`tf.contrib.data.prefetch_to_device()` is not compatible with " + "`tf.data.experimental.prefetch_to_device()` is not compatible with " "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " "over the dataset instead.") diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a753d77580758af9de8410de4a08f7ea278c4c79..6a508fc6ba98740c4d441a064dc8a3e2b321f585 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -24,11 +24,11 @@ import time import numpy as np from tensorflow.contrib import lookup -from tensorflow.contrib.data.python.ops import prefetching_ops -from tensorflow.contrib.data.python.ops import threadpool -from tensorflow.contrib.data.python.ops import unique from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset +from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.data.experimental.ops import threadpool +from tensorflow.python.data.experimental.ops import unique from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 6f02c90368d966b8cf8d0dee09f9d2a5013c90c1..97c299a911c9180bf69faa0fa46527e80eada790 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,6 +6,7 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ + "//tensorflow/contrib/eager/python/examples/densenet", "//tensorflow/contrib/eager/python/examples/gan:mnist", "//tensorflow/contrib/eager/python/examples/l2hmc", "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index c61ec2dbae60a782c0e6589701554b045dcb92ae..d64c8eb9ce122fa277567b2fbc632abfbc72df64 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "mnist", diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py index 955747988536bd21d52df66a35af4aa31b3f7688..1c925e455b9ce9f52b9e1a32fe0b0110ae8f4f41 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -37,26 +37,43 @@ def get_default_hparams(): n_warmup_iters=3) +def step(dynamics, optimizer, samples): + loss, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=l2hmc.compute_loss) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + return loss, samples + + +# To be defunnable, the function cannot return an Operation, so the above +# function is used for defun or eager, and this function is used in graph to be +# able to run the gradient updates. +def graph_step(dynamics, optimizer, samples): + loss, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=l2hmc.compute_loss) + train_op = optimizer.apply_gradients(zip(grads, dynamics.variables)) + + return train_op, loss, samples + + def warmup(dynamics, optimizer, n_iters=1, n_samples=200, - loss_fn=l2hmc.compute_loss): + step_fn=step): """Warmup optimization to reduce overhead.""" samples = tf.random_normal( shape=[n_samples, dynamics.x_dim], dtype=tf.float32) for _ in range(n_iters): - _, grads, samples, _ = l2hmc.loss_and_grads( - dynamics, samples, loss_fn=loss_fn) - optimizer.apply_gradients(zip(grads, dynamics.variables)) + _, samples = step_fn(dynamics, optimizer, samples) def fit(dynamics, samples, optimizer, - loss_fn=l2hmc.compute_loss, + step_fn=step, n_iters=5000, verbose=True, logdir=None): @@ -66,9 +83,7 @@ def fit(dynamics, summary_writer = tf.contrib.summary.create_file_writer(logdir) for i in range(n_iters): - loss, grads, samples, _ = l2hmc.loss_and_grads( - dynamics, samples, loss_fn=loss_fn) - optimizer.apply_gradients(zip(grads, dynamics.variables)) + loss, samples = step_fn(dynamics, optimizer, samples) if verbose: print("Iteration %d: loss %.4f" % (i, loss)) @@ -130,51 +145,48 @@ class L2hmcBenchmark(tf.test.Benchmark): """Benchmark Graph performance.""" hparams = get_default_hparams() - tf.reset_default_graph() - with tf.Graph().as_default(): - energy_fn, _, _ = l2hmc.get_scg_energy_fn() - dynamics = l2hmc.Dynamics( - x_dim=hparams.x_dim, - minus_loglikelihood_fn=energy_fn, - n_steps=hparams.n_steps, - eps=hparams.eps) - x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) - loss, x_out, _ = l2hmc.compute_loss(dynamics, x) - - global_step = tf.Variable(0., name="global_step", trainable=False) - learning_rate = tf.train.exponential_decay( - hparams.learning_rate, global_step, 1000, 0.96, staircase=True) - optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) - train_op = optimizer.minimize(loss, global_step=global_step) - - # Single thread; fairer comparison against eager - session_conf = tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) - - with tf.Session(config=session_conf) as sess: - sess.run(tf.global_variables_initializer()) - - # Warmup to reduce initialization effect when timing - samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) - for _ in range(hparams.n_warmup_iters): - _, _, _, _ = sess.run( - [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) - - # Training - start_time = time.time() - for i in range(hparams.n_iters): - samples, loss_np, _, _ = sess.run( - [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) - print("Iteration %d: loss %.4f" % (i, loss_np)) - wall_time = time.time() - start_time - examples_per_sec = hparams.n_samples / wall_time - - self.report_benchmark( - name="graph_train_%s" % ("gpu" - if tf.test.is_gpu_available() else "cpu"), - iters=hparams.n_iters, - extras={"examples_per_sec": examples_per_sec}, - wall_time=wall_time) + tf.enable_resource_variables() + for sample_size in [10, 25, 50, 100, 200]: + hparams.n_samples = sample_size + tf.reset_default_graph() + with tf.Graph().as_default(): + energy_fn, _, _ = l2hmc.get_scg_energy_fn() + x = tf.random_normal([hparams.n_samples, hparams.x_dim], + dtype=tf.float32) + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + minus_loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + loss, _, _ = l2hmc.compute_loss(dynamics, x) + + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + train_op, loss, _ = graph_step(dynamics, optimizer, x) + + # Single thread; fairer comparison against eager + session_conf = tf.ConfigProto(inter_op_parallelism_threads=1) + + with tf.Session(config=session_conf) as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + for _ in range(hparams.n_warmup_iters): + _, _ = sess.run([train_op, loss]) + + # Training + start_time = time.time() + for i in range(hparams.n_iters): + _, loss_np = sess.run([train_op, loss]) + print("Iteration %d: loss %.4f" % (i, loss_np)) + wall_time = (time.time() - start_time) / hparams.n_iters + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s_%d" % + ("gpu" if tf.test.is_gpu_available() else "cpu", sample_size), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) def benchmark_eager(self): self._benchmark_eager() @@ -186,32 +198,44 @@ class L2hmcBenchmark(tf.test.Benchmark): """Benchmark Eager performance.""" hparams = get_default_hparams() - energy_fn, _, _ = l2hmc.get_scg_energy_fn() - dynamics = l2hmc.Dynamics( - x_dim=hparams.x_dim, - minus_loglikelihood_fn=energy_fn, - n_steps=hparams.n_steps, - eps=hparams.eps) - optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) - loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss - - # Warmup to reduce initialization effect when timing - warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) - - # Training - samples = tf.random_normal( - shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) - start_time = time.time() - fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters) - wall_time = time.time() - start_time - examples_per_sec = hparams.n_samples / wall_time - - self.report_benchmark( - name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else - "cpu", "_defun" if defun else ""), - iters=hparams.n_iters, - extras={"examples_per_sec": examples_per_sec}, - wall_time=wall_time) + for sample_size in [10, 25, 50, 100, 200]: + hparams.n_samples = sample_size + energy_fn, _, _ = l2hmc.get_scg_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + minus_loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + step_fn = tfe.defun(step) if defun else step + + # Warmup to reduce initialization effect when timing + warmup( + dynamics, + optimizer, + n_iters=hparams.n_warmup_iters, + n_samples=hparams.n_samples, + step_fn=step_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + step_fn=step_fn, + n_iters=hparams.n_iters) + wall_time = (time.time() - start_time) / hparams.n_iters + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s_%d" % + ("gpu" if tf.test.is_gpu_available() else "cpu", + "_defun" if defun else "", sample_size), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) del dynamics diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 2f6cfdf31e852d5d69a7a87980c9a441da504cf2..74ce9e84f013d79b3a33ffa79993980b561e366d 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "linear_regression", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 8fae622e12864ddeee0cedd3cf99be8ea5e4bc48..446e3401184ded6bc34ed64cdd720e29a2851855 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -65,7 +65,7 @@ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 551c76b0df71c88919df9cd6d81b4176b23b0ba3..f3bb978875e226f58d6a00e09154191673a97415 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -51,7 +51,9 @@ def random_batch(batch_size): class ResNet50GraphTest(tf.test.TestCase): def testApply(self): - batch_size = 64 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 8 with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) @@ -63,7 +65,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run(init) np_images, _ = random_batch(batch_size) out = sess.run(predictions, feed_dict={images: np_images}) - self.assertAllEqual([64, 1000], out.shape) + self.assertAllEqual([batch_size, 1000], out.shape) def testTrainWithSummary(self): with tf.Graph().as_default(): @@ -87,7 +89,9 @@ class ResNet50GraphTest(tf.test.TestCase): init = tf.global_variables_initializer() self.assertEqual(321, len(tf.global_variables())) - batch_size = 32 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 2 with tf.Session() as sess: sess.run(init) sess.run(tf.contrib.summary.summary_writer_initializer_op()) diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py index 34a9984b0ecc527ad1991c28146246b716e96c98..d85188de030af2bbab1c141b5c090371248110b9 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -169,11 +169,11 @@ class ImageNetInput(object): # Read the data from disk in parallel dataset = dataset.apply( - tf.contrib.data.parallel_interleave( + tf.data.experimental.parallel_interleave( fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True)) if self.cache: dataset = dataset.cache().apply( - tf.contrib.data.shuffle_and_repeat(1024 * 16)) + tf.data.experimental.shuffle_and_repeat(1024 * 16)) else: dataset = dataset.shuffle(1024) @@ -188,9 +188,11 @@ class ImageNetInput(object): # batch size. As long as this validation is done with consistent batch size, # exactly the same images will be used. dataset = dataset.apply( - tf.contrib.data.map_and_batch( - self.dataset_parser, batch_size=batch_size, - num_parallel_batches=self.num_cores, drop_remainder=True)) + tf.data.experimental.map_and_batch( + self.dataset_parser, + batch_size=batch_size, + num_parallel_batches=self.num_cores, + drop_remainder=True)) # Transpose for performance on TPU if self.transpose_input: diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 6a921e19978fdf6e3c20974b2c349bd6923b5782..4f4cc3af6f1d5c626b3e2ea7939ecad0ee2d41f1 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -50,6 +50,9 @@ class RevNetTest(tf.test.TestCase): # Reconstruction could cause numerical error, use double precision for tests config.dtype = tf.float64 config.fused = False # Fused batch norm does not support tf.float64 + # Reduce the batch size for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + config.batch_size = 2 shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) self.x = tf.random_normal(shape=shape, dtype=tf.float64) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index f83eb5c476ed9f45d70849a0de6c0f20973682a5..d500b632ebb97fd12ded3a215b0f1a686194874f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "rnn_colorbot", diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index 4b4792cd49bf8bd4ad46a0371ef0d2f8a07ddd1c..2cc2fcbfeb21ee6218d7912d9a93ea2f7b2ea226 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "rnn_ptb", diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9e7b027ed68935f2bc0ddbd27a1821a663850d --- /dev/null +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -0,0 +1,289 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""EXPERIMENTAL utilities for parameter server training with eager execution. + +Note: this should eventually be merged with the distribution strategy for +ParameterServer. +""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import time + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training.checkpointable import base as checkpointable + + +def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): + """Creates a variable handle with information to do shape inference.""" + container = ops.get_default_graph()._container # pylint: disable=protected-access + if container is None: + container = "" + handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, + shared_name=shared_name, + name=name, + container=container) + if graph_mode: + return handle + + with context.graph_mode(), ops.Graph().as_default() as graph: + h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, + shared_name=shared_name, + name=name, + container=container) + + # Tensor._handle_data contains information for the shape-inference code to + # know the shape and dtype of the variable pointed to by a handle. Since + # shape inference doesn't run in eager mode we copy this data here for when + # the handle is captured by an eager mode function. + # pylint: disable=protected-access + if ops._USE_C_SHAPES: + handle._handle_data = resource_variable_ops.get_resource_handle_data(h) + else: + if h._handle_data is None: + ops.set_shape_and_handle_data_for_outputs(h.op) + handle._handle_data = h._handle_data + # pylint: enable=protected-access + # Clean up op->graph->op reference cycles. + ops.dismantle_graph(graph) + return handle + + +class SharedVariable(resource_variable_ops.ResourceVariable): + """Experimental Variable designed for parameter server training. + + A SharedVariable has a name and two instances of SharedVariable with the + same name will have the same value, even if they are in different Sessions, + as long as they are placed on the same device. + + The storage associated with SharedVariables is also not deleted when they go + out of scope. + """ + + def __init__(self, # pylint: disable=super-init-not-called + initial_value=None, + trainable=True, + name=None, + dtype=None, + constraint=None, + initialize=True, + **unused_kwargs): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. + (Note that initializer functions from init_ops.py must first be bound + to a shape before being used here.) + trainable: If `True`, automatically watches this variable on GradientTape + whenever it's used. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. + If None, either the datatype will be kept (if initial_value is + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + initialize: if True, runs initialization in eager execution; leaves the + variable uninitialized otherwise. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + """ + if initial_value is None: + raise ValueError("initial_value must be specified.") + init_from_fn = callable(initial_value) + + if isinstance(initial_value, ops.Tensor) and hasattr( + initial_value, "graph") and initial_value.graph.building_function: + raise ValueError("Tensor-typed variable initializers must either be " + "wrapped in an init_scope or callable " + "(e.g., `tf.Variable(lambda : " + "tf.truncated_normal([10, 40]))`) when building " + "functions. Please file a feature request if this " + "restriction inconveniences you.") + + if constraint is not None and not callable(constraint): + raise ValueError("The `constraint` argument must be a callable.") + + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + + self._trainable = trainable + self._save_slice_info = None + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() + with ops.name_scope(name, "Variable", [] + if init_from_fn else [initial_value]) as name: + # pylint: disable=protected-access + handle_name = ops._name_from_scope_name(name) + shared_name = handle_name + if init_from_fn: + # Use attr_scope and device(None) to simulate the behavior of + # colocate_with when the variable we want to colocate with doesn't + # yet exist. + if self._in_graph_mode: + with ops.name_scope("Initializer"), ops.device(None): + initial_value = ops.convert_to_tensor( + initial_value(), name="initial_value", dtype=dtype) + self._handle = _eager_safe_variable_handle( + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=shared_name, + name=name, + graph_mode=self._in_graph_mode) + self._shape = initial_value.get_shape() + else: + initial_value = initial_value() + with ops.name_scope("Initializer"): + initial_value = ops.convert_to_tensor( + initial_value, name="initial_value", dtype=dtype) + self._handle = _eager_safe_variable_handle( + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=shared_name, + name=name, + graph_mode=False) + self._shape = initial_value.get_shape() + # pylint: enable=protected-access + + # Or get the initial value from a Tensor or Python object. + else: + with ops.name_scope("Initializer"): + initial_value = ops.convert_to_tensor( + initial_value, name="initial_value", dtype=dtype) + # pylint: disable=protected-access + if (self._in_graph_mode and initial_value is not None and + initial_value.op._get_control_flow_context() is not None): + raise ValueError( + "Initializer for variable %s is from inside a control-flow " + "construct, such as a loop or conditional. When creating a " + "variable inside a loop or conditional, use a lambda as the " + "initializer." % name) + # pylint: enable=protected-access + self._handle = _eager_safe_variable_handle( + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=shared_name, + name=name, + graph_mode=self._in_graph_mode) + self._shape = initial_value.get_shape() + + self._unique_id = shared_name + self._initial_value = initial_value if self._in_graph_mode else None + self._handle_name = handle_name + ":0" + self._dtype = initial_value.dtype.base_dtype + self._constraint = constraint + + if self._in_graph_mode: + with ops.name_scope("IsInitialized"): + self._is_initialized_op = ( + resource_variable_ops.var_is_initialized_op(self._handle)) + if initial_value is not None: + with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): + self._initializer_op = ( + resource_variable_ops.assign_variable_op( + self._handle, + self._try_guard_against_uninitialized_dependencies( + initial_value), + name=n)) + with ops.name_scope("Read"), ops.colocate_with(self._handle): + # Manually assign reads to the handle's device to avoid log + # messages. + with ops.device(self._handle.device): + value = self._read_variable_op() + self._graph_element = value + self._cached_value = None + else: + if initialize: + resource_variable_ops.assign_variable_op(self._handle, + initial_value) + self._is_initialized_op = None + self._initializer_op = None + self._graph_element = None + self._cached_value = None + + self._handle_deleter = None + self._cached_shape_as_list = None + + +@contextlib.contextmanager +def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks): + """Strategy to use parameter servers in eager. + + Creates SharedVariable objects for variables created in this scope. These + SharedVariable objects will be placed round-robin on the parameter servers + specified by the ps_job_name and num_ps_tasks arguments. + + To use parameter servers you need only to wrap your model initialization in + this scope: + + ``` + with tf.contrib.eager.parameter_server_scope( + is_chief, ps_job_name, num_ps_tasks): + my_model = tf.keras.Sequential([...]) # Or + input = tf.keras.Input(...) + .... + my_model = tf.keras.Model(input, output) + my_model.compile(...) + # or other usages of the model. + ``` + + Args: + is_chief: Boolean. Whether this worker is responsible for initializing + variables. + ps_job_name: The name of the ps job in this cluster. + num_ps_tasks: The number of ps tasks to use. + + Yields: + a context manager. + """ + # Note: capturing in a list to allow assignment. + ps_index = [0] + + def variable_creator_scope(unused_next_creator, **kwargs): + kwargs["initialize"] = is_chief + with ops.device( + "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)): + ps_index[0] += 1 + v = SharedVariable(**kwargs) + if not is_chief: + while not resource_variable_ops.var_is_initialized_op(v.handle): + time.sleep(10) + return v + + with variable_scope.variable_creator_scope(variable_creator_scope): + yield diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 13029db975bcbf8a6b31ba3c11d4c2b08edfdb6f..7aa4b598b833c3419af501b49f1509d18f3530d5 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.contrib.eager.python import parameter_server from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 @@ -33,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -45,8 +47,9 @@ def run_sync_and_async(f): @functools.wraps(f) def decorator(self, *args, **kwargs): - with context.execution_mode(context.ASYNC): - f(self, *args, **kwargs) + # TODO(b/117110239): Re-enable. + # with context.execution_mode(context.ASYNC): + # f(self, *args, **kwargs) with context.execution_mode(context.SYNC): f(self, *args, **kwargs) @@ -120,6 +123,24 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x2) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + def testParameterServer(self): + with parameter_server.parameter_server_scope( + is_chief=True, ps_job_name=JOB_NAME, num_ps_tasks=3): + v0 = variables.Variable([1.0], name="v0") + v1 = variables.Variable([2.0], name="v1") + v0.assign(v0 * v1) + self.assertAllEqual(v0.read_value(), [2.0]) + self.assertAllEqual(v0.device, + "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + self.assertAllEqual(v1.device, + "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME) + v1.assign_add(v1) + # Simulate aliasing another variable of the same name as v1 + with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + v1_replica = parameter_server.SharedVariable( + [1.0], name="v1", initialize=False) + self.assertAllEqual(v1_replica.read_value(), [4.0]) + @run_sync_and_async def testSimpleWeightRead(self): """Basic remote eager weight read.""" diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6db311d52de61359995087fb5ca3d5461f74c4c1..1ea00fb7f3c6a19824abc8eb80726bb3bba183aa 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -132,21 +132,11 @@ py_library( srcs = ["python/estimator/dnn_with_layer_annotations.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:variable_scope", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:optimizers", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:utils", ], ) @@ -162,22 +152,13 @@ py_test( ], deps = [ ":dnn_with_layer_annotations", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:dnn_testing_utils", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:pandas_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", "@six_archive//:six", ], ) @@ -283,9 +264,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:exporter", ], ) @@ -297,7 +276,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":exporter", - "//tensorflow/python:platform", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:exporter", ], @@ -502,7 +481,6 @@ py_library( "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", - "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) @@ -557,13 +535,10 @@ py_library( srcs = ["python/estimator/saved_model_estimator.py"], deps = [ ":export", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:export", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model", ], ) @@ -578,16 +553,7 @@ py_test( deps = [ ":export", ":saved_model_estimator", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:export_output", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 78914ecacaf79fd25b33d4159601ab49d2b74c96..419609b1af7b19dc9cf2960e96e71d54d8eb0c9b 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -76,7 +76,7 @@ _allowed_symbols = [ 'stop_if_no_decrease_hook', 'build_raw_supervised_input_receiver_fn', 'build_supervised_input_receiver_fn_from_input_fn', - 'SavedModelEstimator' + 'SavedModelEstimator', 'DNNClassifierWithLayerAnnotations', 'DNNRegressorWithLayerAnnotations', ] diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index 7ed77bcce6f00ed13e9952951800f1017d582f19..b131ed4f12a01a0087390b5bb65f3ac2d5aec657 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees +from tensorflow.python.estimator.canned import head as head_lib def _validate_input_fn_and_repeat_dataset(train_input_fn): @@ -33,7 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn): return _input_fn -class _BoostedTreesEstimator(estimator.Estimator): +def _is_classification_head(head): + """Infers if the head is a classification head.""" + # Check using all classification heads defined in canned/head.py. However, it + # is not a complete list - it does not check for other classification heads + # not defined in the head library. + # pylint: disable=protected-access + return isinstance(head, + (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss, + head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss)) + # pylint: enable=protected-access + + +class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access """An Estimator for Tensorflow Boosted Trees models.""" def __init__(self, @@ -62,7 +75,7 @@ class _BoostedTreesEstimator(estimator.Estimator): layer. head: the `Head` instance defined for Estimator. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -96,9 +109,12 @@ class _BoostedTreesEstimator(estimator.Estimator): negative gain). For pre and post pruning, you MUST provide tree_complexity >0. + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. """ - # pylint:disable=protected-access # HParams for the model. + # pylint: disable=protected-access tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, tree_complexity, min_node_weight, center_bias, pruning_mode) @@ -115,8 +131,14 @@ class _BoostedTreesEstimator(estimator.Estimator): config=config) super(_BoostedTreesEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) - # pylint:enable=protected-access + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_columns=feature_columns, + head=head, + center_bias=center_bias, + is_classification=_is_classification_head(head)) + # pylint: enable=protected-access def boosted_trees_classifier_train_in_memory( @@ -177,7 +199,7 @@ def boosted_trees_classifier_train_in_memory( the model. All items in the set should be instances of classes derived from `FeatureColumn`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. n_classes: number of label classes. Default is binary classification. Multiclass support is not yet implemented. @@ -323,7 +345,7 @@ def boosted_trees_regressor_train_in_memory( the model. All items in the set should be instances of classes derived from `FeatureColumn`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. label_dimension: Number of regression targets per example. Multi-dimensional support is not yet implemented. diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index b1581f37509b5dc2bec98942e88c024905f25d93..e23d9c0fc4c32ce0ce23dcf4be518577795dd35f 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -360,5 +360,79 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): [pred['predictions'] for pred in predictions]) +class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._head = canned_boosted_trees._create_regression_head(label_dimension=1) + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) for i in range(NUM_FEATURES) + } + + def testContribEstimatorThatDFCIsInPredictions(self): + # pylint:disable=protected-access + head = canned_boosted_trees._create_regression_head(label_dimension=1) + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + head=head, + n_trees=1, + max_depth=5, + center_bias=True) + # pylint:enable=protected-access + + num_steps = 100 + # Train for a few steps. Validate debug outputs in prediction dicts. + est.train(train_input_fn, steps=num_steps) + debug_predictions = est.experimental_predict_with_explanations( + predict_input_fn) + biases, dfcs = zip(*[(pred['bias'], pred['dfc']) + for pred in debug_predictions]) + self.assertAllClose([1.8] * 5, biases) + self.assertAllClose(({ + 0: -0.070499420166015625, + 1: -0.095000028610229492, + 2: 0.0 + }, { + 0: -0.53763031959533691, + 1: 0.063333392143249512, + 2: 0.0 + }, { + 0: -0.51756942272186279, + 1: -0.095000028610229492, + 2: 0.0 + }, { + 0: 0.1563495397567749, + 1: 0.063333392143249512, + 2: 0.0 + }, { + 0: 0.96934974193572998, + 1: 0.063333392143249512, + 2: 0.0 + }), dfcs) + + # Assert sum(dfcs) + bias == predictions. + expected_predictions = [[1.6345005], [1.32570302], [1.1874305], + [2.01968288], [2.83268309]] + predictions = [ + [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases) + ] + self.assertAllClose(expected_predictions, predictions) + + # Test when user doesn't include bias or dfc in predict_keys. + debug_predictions = est.experimental_predict_with_explanations( + predict_input_fn, predict_keys=['predictions']) + for prediction_dict in debug_predictions: + self.assertTrue('bias' in prediction_dict) + self.assertTrue('dfc' in prediction_dict) + self.assertTrue('predictions' in prediction_dict) + self.assertEqual(len(prediction_dict), 3) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index 724bc2c82f8289bbaa19a1dbbc1dc81b6e158e02..4e7965ef265022214f88ed74f4c8502fc8a4c897 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -118,7 +118,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator): head: A `_Head` instance constructed with a method such as `tf.contrib.estimator.multi_label_head`. model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. linear_feature_columns: An iterable containing all the feature columns used by linear part of the model. All items in the set must be diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py index 152431d1b205845945cc2c079b747f81d739026f..40a91175b71f27bb9ca72a238a5aea172cf4c360 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -24,7 +24,6 @@ import pickle from google.protobuf.any_pb2 import Any from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import dnn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops @@ -68,7 +67,7 @@ def _to_any_wrapped_tensor_info(tensor): return any_buf -def make_input_layer_with_layer_annotations(original_input_layer, mode): +def make_input_layer_with_layer_annotations(original_input_layer): """Make an input_layer replacement function that adds layer annotations.""" def input_layer_with_layer_annotations(features, @@ -76,7 +75,9 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode): weight_collections=None, trainable=True, cols_to_vars=None, - cols_to_output_tensors=None): + scope=None, + cols_to_output_tensors=None, + from_template=False): """Returns a dense `Tensor` as input layer based on given `feature_columns`. Generally a single example in training data is described with @@ -112,9 +113,12 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode): 'some_variable:0' shape=(5, 10), 1. @@ -301,9 +303,9 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name def _model_fn(features, labels, mode, config): with _monkey_patch( - feature_column_lib, 'input_layer', - make_input_layer_with_layer_annotations(feature_column_lib.input_layer, - mode)): + feature_column_lib, '_internal_input_layer', + make_input_layer_with_layer_annotations( + feature_column_lib._internal_input_layer)): # pylint: disable=protected-access return original.model_fn(features, labels, mode, config) return estimator.Estimator( @@ -422,9 +424,9 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name def _model_fn(features, labels, mode, config): with _monkey_patch( - feature_column_lib, 'input_layer', - make_input_layer_with_layer_annotations(feature_column_lib.input_layer, - mode)): + feature_column_lib, '_internal_input_layer', + make_input_layer_with_layer_annotations( + feature_column_lib._internal_input_layer)): # pylint: disable=protected-access return original.model_fn(features, labels, mode, config) return estimator.Estimator( diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py index 3eab21d5acaf26f14a73e7fa8e9c50fffc22fe9c..cafe8279c714bf5d50be61921c9070ca982b99c9 100644 --- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import operator import os @@ -56,6 +57,13 @@ def make_early_stopping_hook(estimator, tf.estimator.train_and_evaluate(estimator, train_spec, ...) ``` + Caveat: Current implementation supports early-stopping both training and + evaluation in local mode. In distributed mode, training can be stopped but + evaluation (where it's a separate job) will indefinitely wait for new model + checkpoints to evaluate, so you will need other means to detect and stop it. + Early-stopping evaluation in distributed mode requires changes in + `train_and_evaluate` API and will be addressed in a future revision. + Args: estimator: A `tf.estimator.Estimator` instance. should_stop_fn: `callable`, function that takes no arguments and returns a @@ -108,6 +116,13 @@ def stop_if_higher_hook(estimator, tf.estimator.train_and_evaluate(estimator, train_spec, ...) ``` + Caveat: Current implementation supports early-stopping both training and + evaluation in local mode. In distributed mode, training can be stopped but + evaluation (where it's a separate job) will indefinitely wait for new model + checkpoints to evaluate, so you will need other means to detect and stop it. + Early-stopping evaluation in distributed mode requires changes in + `train_and_evaluate` API and will be addressed in a future revision. + Args: estimator: A `tf.estimator.Estimator` instance. metric_name: `str`, metric to track. "loss", "accuracy", etc. @@ -157,6 +172,13 @@ def stop_if_lower_hook(estimator, tf.estimator.train_and_evaluate(estimator, train_spec, ...) ``` + Caveat: Current implementation supports early-stopping both training and + evaluation in local mode. In distributed mode, training can be stopped but + evaluation (where it's a separate job) will indefinitely wait for new model + checkpoints to evaluate, so you will need other means to detect and stop it. + Early-stopping evaluation in distributed mode requires changes in + `train_and_evaluate` API and will be addressed in a future revision. + Args: estimator: A `tf.estimator.Estimator` instance. metric_name: `str`, metric to track. "loss", "accuracy", etc. @@ -206,6 +228,13 @@ def stop_if_no_increase_hook(estimator, tf.estimator.train_and_evaluate(estimator, train_spec, ...) ``` + Caveat: Current implementation supports early-stopping both training and + evaluation in local mode. In distributed mode, training can be stopped but + evaluation (where it's a separate job) will indefinitely wait for new model + checkpoints to evaluate, so you will need other means to detect and stop it. + Early-stopping evaluation in distributed mode requires changes in + `train_and_evaluate` API and will be addressed in a future revision. + Args: estimator: A `tf.estimator.Estimator` instance. metric_name: `str`, metric to track. "loss", "accuracy", etc. @@ -256,6 +285,13 @@ def stop_if_no_decrease_hook(estimator, tf.estimator.train_and_evaluate(estimator, train_spec, ...) ``` + Caveat: Current implementation supports early-stopping both training and + evaluation in local mode. In distributed mode, training can be stopped but + evaluation (where it's a separate job) will indefinitely wait for new model + checkpoints to evaluate, so you will need other means to detect and stop it. + Early-stopping evaluation in distributed mode requires changes in + `train_and_evaluate` API and will be addressed in a future revision. + Args: estimator: A `tf.estimator.Estimator` instance. metric_name: `str`, metric to track. "loss", "accuracy", etc. @@ -306,7 +342,8 @@ def read_eval_metrics(eval_dir): metrics[value.tag] = value.simple_value if metrics: eval_metrics_dict[event.step] = metrics - return eval_metrics_dict + return collections.OrderedDict( + sorted(eval_metrics_dict.items(), key=lambda t: t[0])) def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold, diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index 66c46e66b77e8819268f7fe084abdc785077f116..49f7bbd32009cc80ef3fa70917ac26a8b752ef6d 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -53,6 +53,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook): ``` Current limitations of this approach are: + * It doesn't support multi-node distributed mode. * It doesn't support saveable objects other than variables (such as boosted tree support) diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index c6c6cad95a7575224c47bb5ec36e243691fed371..62ffad56da324ea3765dfdac64f3ef00d9b17a38 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -294,7 +294,7 @@ class InMemoryEvaluatorHookTest(test.TestCase): def model_fn(features, labels, mode): _, _ = features, labels - w = variables.Variable( + w = variables.VariableV1( initial_value=[0.], trainable=False, collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index ce758992140d43529037b14cbbf958d5aa763fb4..6e793c830244e64cd11c4054918c18a8251be7ac 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -233,6 +233,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None): """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=False) + + def _create_tpu_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): + """See `_Head`.""" + return self._create_estimator_spec( + features=features, mode=mode, logits=logits, labels=labels, + optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=True) + + def _create_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, use_tpu=False): + """Returns `EstimatorSpec` or `TPUEstimatorSpec`.""" if isinstance(logits, dict): logits_dict = logits else: @@ -255,14 +271,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access spec = self._merge_train( all_estimator_spec=all_estimator_spec, optimizer=optimizer, - train_op_fn=train_op_fn) + train_op_fn=train_op_fn, + use_tpu=use_tpu) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec if mode == model_fn.ModeKeys.PREDICT: - return self._merge_predict(all_estimator_spec) + return self._merge_predict(all_estimator_spec, use_tpu=use_tpu) if mode == model_fn.ModeKeys.EVAL: - return self._merge_eval(all_estimator_spec) + return self._merge_eval(all_estimator_spec, use_tpu=use_tpu) raise ValueError('mode={} unrecognized'.format(mode)) def _split_logits(self, logits): @@ -284,28 +301,28 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): - """Merges list of `EstimatorSpec` for training. + def _merge_train( + self, all_estimator_spec, optimizer, train_op_fn, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for training. Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. optimizer: `Optimizer` instance to create train op. See `create_estimator_spec` documentation for more details. train_op_fn: Function to create train op. Used if `optimizer` is `None`. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for TRAIN. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for TRAIN. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode. """ losses = [] - metrics = {} for spec in all_estimator_spec: losses.append(spec.loss) - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) if optimizer is not None: if train_op_fn is not None: @@ -317,20 +334,23 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access else: raise ValueError('train_op_fn and optimizer cannot both be None.') - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op, - eval_metric_ops=metrics) + train_op=train_op) - def _merge_predict(self, all_estimator_spec): - """Merges list of `EstimatorSpec` for prediction. + def _merge_predict(self, all_estimator_spec, use_tpu=False): + """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for prediction. Args: - all_estimator_spec: list of `EstimatorSpec` for the individual heads. + all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the + individual heads. + use_tpu: If `True`, returns `TPUEstimatorSpec`. Returns: - `EstimatorSpec` that merges all heads for PREDICT. + `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for PREDICT. """ predictions = {} export_outputs = { @@ -357,20 +377,29 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access export_output_lib.PredictOutput(merged_predict_outputs)) - return model_fn.EstimatorSpec( + spec_type = ( + model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access + return spec_type( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs=export_outputs) - def _merge_eval(self, all_estimator_spec): + def _merge_eval(self, all_estimator_spec, use_tpu=False): """Merges list of `EstimatorSpec` for eval. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. + use_tpu: If `True`, will raise `NotImplementedError`, because TPU is not + yet supported for eval. Returns: `EstimatorSpec` that merges all heads for EVAL. + Raises: + NotImplementedError: If `use_tpu` is `True`. """ + if use_tpu: + raise NotImplementedError( + 'TPU evaluation is not implemented for multi_head.') predictions = {} metrics = {} losses = [] diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 2b4d5f526199c500ad77a0422215381ac3a1cf69..a602f87b4a2b4062efddf819522fb2d1eeceaabe 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,7 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads_logits_dict(self): + def _test_predict_two_heads_logits_dict(self, use_tpu): """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -121,10 +121,16 @@ class MultiHeadTest(test.TestCase): 'head2': _sigmoid(logits['head2']), } - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.PREDICT, - logits=logits) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) self.assertItemsEqual( (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', @@ -175,6 +181,12 @@ class MultiHeadTest(test.TestCase): sess.run( spec.export_outputs['head2/predict'].outputs['probabilities'])) + def test_predict_two_heads_logits_dict(self): + self._test_predict_two_heads_logits_dict(use_tpu=False) + + def test_predict_two_heads_logits_dict_tpu(self): + self._test_predict_two_heads_logits_dict(use_tpu=True) + def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') @@ -350,6 +362,31 @@ class MultiHeadTest(test.TestCase): rtol=tol, atol=tol) + def test_eval_tpu(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = { + 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], + dtype=np.float32), + } + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + + with self.assertRaisesRegexp( + NotImplementedError, + r'TPU evaluation is not implemented for multi_head\.'): + multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + def test_train_create_loss_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') multi_head = multi_head_lib.multi_head([head1]) @@ -587,7 +624,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - def test_train_two_heads_with_weights(self): + def _test_train_two_heads_with_weights(self, use_tpu): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head( @@ -619,12 +656,20 @@ class MultiHeadTest(test.TestCase): [constant_op.constant(expected_train_result), string_ops.as_string(loss, precision=3)]) - spec = multi_head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels, - train_op_fn=_train_op_fn) + if use_tpu: + spec = multi_head._create_tpu_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn).as_estimator_spec() + else: + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) self.assertIsNotNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) @@ -649,6 +694,12 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, }, summary_str, tol) + def test_train_two_heads_with_weights(self): + self._test_train_two_heads_with_weights(use_tpu=False) + + def test_train_two_heads_with_weights_tpu(self): + self._test_train_two_heads_with_weights(use_tpu=True) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 98660bb7317ae76a7da7c90a5c890ab8e69037fe..c595f473950e28cd75cd1b56c1b3d409333dbc74 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): return rnn_cell_fn -def _concatenate_context_input(sequence_input, context_input): - """Replicates `context_input` across all timesteps of `sequence_input`. - - Expands dimension 1 of `context_input` then tiles it `sequence_length` times. - This value is appended to `sequence_input` on dimension 2 and the result is - returned. - - Args: - sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, - padded_length, d0]`. - context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. - - Returns: - A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, - d0 + d1]`. - - Raises: - ValueError: If `sequence_input` does not have rank 3 or `context_input` does - not have rank 2. - """ - seq_rank_check = check_ops.assert_rank( - sequence_input, - 3, - message='sequence_input must have rank 3', - data=[array_ops.shape(sequence_input)]) - seq_type_check = check_ops.assert_type( - sequence_input, - dtypes.float32, - message='sequence_input must have dtype float32; got {}.'.format( - sequence_input.dtype)) - ctx_rank_check = check_ops.assert_rank( - context_input, - 2, - message='context_input must have rank 2', - data=[array_ops.shape(context_input)]) - ctx_type_check = check_ops.assert_type( - context_input, - dtypes.float32, - message='context_input must have dtype float32; got {}.'.format( - context_input.dtype)) - with ops.control_dependencies( - [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): - padded_length = array_ops.shape(sequence_input)[1] - tiled_context_input = array_ops.tile( - array_ops.expand_dims(context_input, 1), - array_ops.concat([[1], [padded_length], [1]], 0)) - return array_ops.concat([sequence_input, tiled_context_input], 2) - - def _select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. @@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, context_input = feature_column_lib.input_layer( features=features, feature_columns=context_feature_columns) - sequence_input = _concatenate_context_input(sequence_input, - context_input) + sequence_input = seq_fc.concatenate_context_input( + context_input, sequence_input) cell = rnn_cell_fn(mode) # Ignore output state. diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 1aebed348dcacf8fbe90421bdc7ff25f5b7bcc4a..89506ee6615cd838b0fe651e13eb3e7dd35d2aef 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -25,12 +25,12 @@ import tempfile import numpy as np import six -from tensorflow.contrib.data.python.ops import readers from tensorflow.contrib.estimator.python.estimator import head as head_lib from tensorflow.contrib.estimator.python.estimator import rnn from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.experimental.ops import readers from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import parsing_utils diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 9e1f14f9905d584287864c15d9b6f9c152d17787..e344d7a23b55134612aab430b50cf065bd1095e4 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -64,7 +64,6 @@ tf_custom_op_py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/estimator", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index e076631bc16fd379a2ad31af9055a7388d98c7ca..d365ad111760247fc18b730657390f07ba6b865e 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -154,10 +154,10 @@ class GmmAlgorithm(object): def _create_variables(self): """Initializes GMM algorithm.""" init_value = array_ops.constant([], dtype=dtypes.float32) - self._means = variables.Variable(init_value, - name=self.CLUSTERS_VARIABLE, - validate_shape=False) - self._covs = variables.Variable( + self._means = variables.VariableV1(init_value, + name=self.CLUSTERS_VARIABLE, + validate_shape=False) + self._covs = variables.VariableV1( init_value, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False) # Mixture weights, representing the probability that a randomly # selected unobservable data (in EM terms) was generated by component k. @@ -165,9 +165,9 @@ class GmmAlgorithm(object): array_ops.tile([1.0 / self._num_classes], [self._num_classes]), name=self.CLUSTERS_WEIGHT, validate_shape=False) - self._cluster_centers_initialized = variables.Variable(False, - dtype=dtypes.bool, - name='initialized') + self._cluster_centers_initialized = variables.VariableV1(False, + dtype=dtypes.bool, + name='initialized') def _initialize_variables(self, data, initial_means=None): """Initializes variables. diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 9bdbd050152261daff803e6e71abea93406402ed..75d577f42958d97ccb2632798e86ae059c399cb4 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -420,13 +420,13 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): def test_sweeps(self): - is_row_sweep_var = variables.Variable(True) - is_sweep_done_var = variables.Variable(False) - init_done = variables.Variable(False) - row_prep_done = variables.Variable(False) - col_prep_done = variables.Variable(False) - row_train_done = variables.Variable(False) - col_train_done = variables.Variable(False) + is_row_sweep_var = variables.VariableV1(True) + is_sweep_done_var = variables.VariableV1(False) + init_done = variables.VariableV1(False) + row_prep_done = variables.VariableV1(False) + col_prep_done = variables.VariableV1(False) + row_train_done = variables.VariableV1(False) + col_train_done = variables.VariableV1(False) init_op = state_ops.assign(init_done, True) row_prep_op = state_ops.assign(row_prep_done, True) @@ -486,7 +486,7 @@ class StopAtSweepHookTest(test.TestCase): def test_stop(self): hook = wals_lib._StopAtSweepHook(last_sweep=10) - completed_sweeps = variables.Variable( + completed_sweeps = variables.VariableV1( 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS) train_op = state_ops.assign_add(completed_sweeps, 1) hook.begin() diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index aab7d0c9e8874269bfa5f33193b0dc0ba4bbc9cd..a926ffd5982116a21dc7a0fd1ff957d4ecc6bf94 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -27,6 +27,7 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", @@ -46,9 +47,29 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "sequence_feature_column_integration_test", + srcs = ["python/feature_column/sequence_feature_column_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/feature_column", + "//tensorflow/python/keras:layers", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 05bcdac2caa77062f9a8a44a948d2897b439ea1f..dd6da35ed009c07ad3819e7860a283c7837c1f83 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope # pylint: disable=protected-access -# TODO(b/73827486): Support SequenceExample. def sequence_input_layer( @@ -110,6 +109,7 @@ def sequence_input_layer( output_tensors = [] sequence_lengths = [] ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): ordered_columns.append(column) with variable_scope.variable_scope( @@ -121,17 +121,67 @@ def sequence_input_layer( # Flattens the final dimension to produce a 3D Tensor. num_elements = column._variable_shape.num_elements() shape = array_ops.shape(dense_tensor) + target_shape = [shape[0], shape[1], num_elements] output_tensors.append( - array_ops.reshape( - dense_tensor, - shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + array_ops.reshape(dense_tensor, shape=target_shape)) sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length +def concatenate_context_input(context_input, sequence_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + def sequence_categorical_column_with_identity( key, num_buckets, default_value=None): """Returns a feature column that represents sequences of integers. @@ -453,9 +503,17 @@ class _SequenceNumericColumn( [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], axis=0) dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) - sequence_length = fc._sequence_length_from_sparse_tensor( - sp_tensor, num_elements=self._variable_shape.num_elements()) + + # Get the number of timesteps per example + # For the 2D case, the raw values are grouped according to num_elements; + # for the 3D case, the grouping happens in the third dimension, and + # sequence length is not affected. + num_elements = (self._variable_shape.num_elements() + if sp_tensor.shape.ndims == 2 else 1) + seq_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=num_elements) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( - dense_tensor=dense_tensor, sequence_length=sequence_length) + dense_tensor=dense_tensor, sequence_length=seq_length) # pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ca363627eace15e039679545366648df174c33 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -0,0 +1,280 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration test for sequence feature columns with SequenceExamples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string +import tempfile + +from google.protobuf import text_format + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class SequenceFeatureColumnIntegrationTest(test.TestCase): + + def _make_sequence_example(self): + example = example_pb2.SequenceExample() + example.context.feature['int_ctx'].int64_list.value.extend([5]) + example.context.feature['float_ctx'].float_list.value.extend([123.6]) + for val in range(0, 10, 2): + feat = feature_pb2.Feature() + feat.int64_list.value.extend([val] * val) + example.feature_lists.feature_list['int_list'].feature.extend([feat]) + for val in range(1, 11, 2): + feat = feature_pb2.Feature() + feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val) + example.feature_lists.feature_list['str_list'].feature.extend([feat]) + + return example + + def _build_feature_columns(self): + col = fc.categorical_column_with_identity( + 'int_ctx', num_buckets=100) + ctx_cols = [ + fc.embedding_column(col, dimension=10), + fc.numeric_column('float_ctx')] + + identity_col = sfc.sequence_categorical_column_with_identity( + 'int_list', num_buckets=10) + bucket_col = sfc.sequence_categorical_column_with_hash_bucket( + 'bytes_list', hash_bucket_size=100) + seq_cols = [ + fc.embedding_column(identity_col, dimension=10), + fc.embedding_column(bucket_col, dimension=20)] + + return ctx_cols, seq_cols + + def test_sequence_example_into_input_layer(self): + examples = [_make_sequence_example().SerializeToString()] * 100 + ctx_cols, seq_cols = self._build_feature_columns() + + def _parse_example(example): + ctx, seq = parsing_ops.parse_single_sequence_example( + example, + context_features=fc.make_parse_example_spec(ctx_cols), + sequence_features=fc.make_parse_example_spec(seq_cols)) + ctx.update(seq) + return ctx + + ds = dataset_ops.Dataset.from_tensor_slices(examples) + ds = ds.map(_parse_example) + ds = ds.batch(20) + + # Test on a single batch + features = ds.make_one_shot_iterator().get_next() + + # Tile the context features across the sequence features + seq_layer, _ = sfc.sequence_input_layer(features, seq_cols) + ctx_layer = fc.input_layer(features, ctx_cols) + input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + + rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) + output = rnn_layer(input_layer) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + features_r = sess.run(features) + self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6]) + + output_r = sess.run(output) + self.assertAllEqual(output_r.shape, [20, 10]) + + +class SequenceExampleParsingTest(test.TestCase): + + def test_seq_ex_in_sequence_categorical_column_with_identity(self): + self._test_parsed_sequence_example( + 'int_list', sfc.sequence_categorical_column_with_identity, + 10, [3, 6], [2, 4, 6]) + + def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket, + 10, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list, + list(string.ascii_lowercase), [3, 4], + [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self): + _, fname = tempfile.mkstemp() + with open(fname, 'w') as f: + f.write(string.ascii_lowercase) + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file, + fname, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def _test_parsed_sequence_example( + self, col_name, col_fn, col_arg, shape, values): + """Helper function to check that each FeatureColumn parses correctly. + + Args: + col_name: string, name to give to the feature column. Should match + the name that the column will parse out of the features dict. + col_fn: function used to create the feature column. For example, + sequence_numeric_column. + col_arg: second arg that the target feature column is expecting. + shape: the expected dense_shape of the feature after parsing into + a SparseTensor. + values: the expected values at index [0, 2, 6] of the feature + after parsing into a SparseTensor. + """ + example = _make_sequence_example() + columns = [ + fc.categorical_column_with_identity('int_ctx', num_buckets=100), + fc.numeric_column('float_ctx'), + col_fn(col_name, col_arg) + ] + context, seq_features = parsing_ops.parse_single_sequence_example( + example.SerializeToString(), + context_features=fc.make_parse_example_spec(columns[:2]), + sequence_features=fc.make_parse_example_spec(columns[2:])) + + with self.cached_session() as sess: + ctx_result, seq_result = sess.run([context, seq_features]) + self.assertEqual(list(seq_result[col_name].dense_shape), shape) + self.assertEqual( + list(seq_result[col_name].values[[0, 2, 6]]), values) + self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1]) + self.assertEqual(ctx_result['int_ctx'].values[0], 5) + self.assertEqual(list(ctx_result['float_ctx'].shape), [1]) + self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1) + + +_SEQ_EX_PROTO = """ +context { + feature { + key: "float_ctx" + value { + float_list { + value: 123.6 + } + } + } + feature { + key: "int_ctx" + value { + int64_list { + value: 5 + } + } + } +} +feature_lists { + feature_list { + key: "bytes_list" + value { + feature { + bytes_list { + value: "a" + } + } + feature { + bytes_list { + value: "b" + value: "c" + } + } + feature { + bytes_list { + value: "d" + value: "e" + value: "f" + value: "g" + } + } + } + } + feature_list { + key: "float_list" + value { + feature { + float_list { + value: 1.0 + } + } + feature { + float_list { + value: 3.0 + value: 3.0 + value: 3.0 + } + } + feature { + float_list { + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + } + } + } + } + feature_list { + key: "int_list" + value { + feature { + int64_list { + value: 2 + value: 2 + } + } + feature { + int64_list { + value: 4 + value: 4 + value: 4 + value: 4 + } + } + feature { + int64_list { + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + } + } + } + } +} +""" + + +def _make_sequence_example(): + example = example_pb2.SequenceExample() + return text_format.Parse(_SEQ_EX_PROTO, example) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 45d7b740462ca21139e2e93e34b43668f1e08a94..707f93b2da5d24a3c1e5c6097a21d8fed4c11b8b 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc @@ -28,28 +29,63 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session -class SequenceInputLayerTest(test.TestCase): +class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [2, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[2], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 2, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]], + # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_embedding_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): - def test_embedding_column(self): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) vocabulary_size = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [2, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - embedding_dimension_a = 2 embedding_values_a = ( (1., 2.), # id 0 @@ -70,14 +106,6 @@ class SequenceInputLayerTest(test.TestCase): return embedding_values return _initializer - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [2, 0] - [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], - ] - expected_sequence_length = [1, 2] - categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column_a = fc.embedding_column( @@ -233,29 +261,56 @@ class SequenceInputLayerTest(test.TestCase): }, feature_columns=shared_embedding_columns) - def test_indicator_column(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [1, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 1, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[1], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 1, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]], + # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -] + [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_indicator_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + vocabulary_size_a = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) vocabulary_size_b = 2 - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [1, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 1, 0), - dense_shape=(2, 2)) - - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [1, 0] - [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], - ] - expected_sequence_length = [1, 2] categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) @@ -298,18 +353,34 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[indicator_column_a]) - def test_numeric_column(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_input_layer = [ - [[0.], [1.]], - [[10.], [0.]], - ] - expected_sequence_length = [2, 1] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1] + # example 1, [10.] + 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + [[0.], [1.]], + [[10.], [0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_numeric_column( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa') input_layer, sequence_length = sfc.sequence_input_layer( @@ -321,21 +392,40 @@ class SequenceInputLayerTest(test.TestCase): self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) - def test_numeric_column_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.] + # example 1, [10., 11., 12., 13.] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + ) + def test_numeric_column_multi_dim( + self, sparse_input_args, expected_input_layer, expected_sequence_length): """Tests sequence_input_layer for multi-dimensional numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - # The output of numeric_column._get_dense_tensor should be flattened. - expected_input_layer = [ - [[0., 1., 2., 3.], [4., 5., 6., 7.]], - [[10., 11., 12., 13.], [0., 0., 0., 0.]], - ] - expected_sequence_length = [2, 1] + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -377,6 +467,138 @@ class SequenceInputLayerTest(test.TestCase): r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): sess.run(sequence_length) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_shape': [2, 2, 4]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_shape': [2, 2, 4]}, + ) + def test_static_shape_from_tensors_numeric( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected_shape': [4, 2, 3]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected_shape': [4, 2, 3]} + ) + def test_static_shape_from_tensors_indicator( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=3) + indicator_column = fc.indicator_column(categorical_column) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, feature_columns=[indicator_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + +class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): + """Tests the utility fn concatenate_context_input.""" + + def test_concatenate_context_input(self): + seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2)) + context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + input_layer = sfc.concatenate_context_input(context_input, seq_input) + + expected = np.array([ + [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]], + [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]] + ], dtype=np.float32) + with monitored_session.MonitoredSession() as sess: + output = sess.run(input_layer) + self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_3', + 'seq_input_arg': np.arange(100).reshape(10, 10)}, + {'testcase_name': 'rank_gt_3', + 'seq_input_arg': np.arange(100).reshape(5, 5, 2, 2)} + ) + def test_sequence_input_throws_error(self, seq_input_arg): + seq_input = ops.convert_to_tensor(seq_input_arg) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'): + sfc.concatenate_context_input(context_input, seq_input) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_2', + 'context_input_arg': np.arange(100)}, + {'testcase_name': 'rank_gt_2', + 'context_input_arg': np.arange(100).reshape(5, 5, 4)} + ) + def test_context_input_throws_error(self, context_input_arg): + context_input = ops.convert_to_tensor(context_input_arg) + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_seq_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'sequence_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_context_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'context_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + class InputLayerTest(test.TestCase): """Tests input_layer with sequence feature columns.""" @@ -443,75 +665,83 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) -class SequenceCategoricalColumnWithIdentityTest(test.TestCase): - - def test_get_sparse_tensors(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((1, 2, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) +class SequenceCategoricalColumnWithIdentityTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((1, 2, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - def test_get_sparse_tensors_inputs3d(self): - """Tests _get_sparse_tensors when the input is already 3D Tensor.""" - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=(1, 2, 0), - dense_shape=(2, 2, 1)) - - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r'Column aaa expected ID tensor of rank 2\.\s*' - r'id_tensor shape:\s*\[2 2 1\]'): - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': inputs})) - with monitored_session.MonitoredSession() as sess: - id_weight_pair.id_tensor.eval(session=sess) - - -class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): - - def test_get_sparse_tensors(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithHashBucketTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('omar', 'stringer', 'marlo'), - dense_shape=(2, 2)) - - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - # Ignored to avoid hash dependence in test. - values=np.array((0, 0, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_indices_shape( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) + self, expected, id_weight_pair.id_tensor.eval(session=sess)) -class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): +class SequenceCategoricalColumnWithVocabularyFileTest( + test.TestCase, parameterized.TestCase): def _write_vocab(self, vocab_strings, file_name): vocab_file = os.path.join(self.get_temp_dir(), file_name) @@ -527,68 +757,125 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): 'wire_vocabulary.txt') self._wire_vocabulary_size = 3 - def test_get_sparse_tensors(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_vocabulary_file( key='aaa', vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - -class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): - - def test_get_sparse_tensors(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyListTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((2, -1, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - -class SequenceEmbeddingColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[7., 11.], [0., 0.]], + # example 1, ids [[0, 1], [2]] + [[2, 3.5], [7., 11.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [[1], [0, 2]] + [[3., 5.], [4., 6.5]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - embedding_dimension = 2 embedding_values = ( (1., 2.), # id 0 @@ -601,17 +888,6 @@ class SequenceEmbeddingColumnTest(test.TestCase): self.assertIsNone(partition_info) return embedding_values - expected_lookups = [ - # example 0, ids [2] - [[7., 11.], [0., 0.]], - # example 1, ids [0, 1] - [[1., 2.], [3., 5.]], - # example 2, ids [] - [[0., 0.], [0., 0.]], - # example 3, ids [1] - [[3., 5.], [0., 0.]], - ] - categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) embedding_column = fc.embedding_column( @@ -619,24 +895,36 @@ class SequenceEmbeddingColumnTest(test.TestCase): initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertItemsEqual( ('embedding_weights:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) - self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, embedding_lookup.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) @@ -644,7 +932,7 @@ class SequenceEmbeddingColumnTest(test.TestCase): categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -855,56 +1143,89 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b, sequence_length_b.eval(session=sess)) -class SequenceIndicatorColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): +class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [2, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 2, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [[0, 1], [2]] + [[1., 1., 0.], [0., 0., 1.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [[1], [2, 2]] + [[0., 1., 0.], [0., 0., 2.]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - - expected_lookups = [ - # example 0, ids [2] - [[0., 0., 1.], [0., 0., 0.]], - # example 1, ids [0, 1] - [[1., 0., 0.], [0., 1., 0.]], - # example 2, ids [] - [[0., 0., 0.], [0., 0., 0.]], - # example 3, ids [1] - [[0., 1., 0.], [0., 0., 0.]], - ] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) - - def test_sequence_length(self): + self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) indicator_column = fc.indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -938,7 +1259,7 @@ class SequenceIndicatorColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) -class SequenceNumericColumnTest(test.TestCase): +class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): a = sfc.sequence_numeric_column('aaa') @@ -971,25 +1292,37 @@ class SequenceNumericColumnTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'must be a callable'): sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') - def test_get_sequence_dense_tensor(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_dense_tensor = [ - [[0.], [1.]], - [[10.], [0.]], - ] + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, values [0., 1] + # example 1, [10.] + 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected': [ + [[0.], [1.]], + [[10.], [0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]]}, + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) numeric_column = sfc.sequence_numeric_column('aaa') dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) + self.assertAllEqual(expected, dense_tensor.eval(session=sess)) def test_get_sequence_dense_tensor_with_normalizer_fn(self): @@ -1026,41 +1359,35 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_get_sequence_dense_tensor_with_shape(self): - """Tests get_sequence_dense_tensor with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_dense_tensor = [ - [[0., 1., 2.], [3., 4., 5.]], - [[10., 11., 12.], [0., 0., 0.]], - ] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def test_get_dense_tensor_multi_dim(self): + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_dense_tensor': [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]}, + {'testcase_name': '3D', + 'sparse_input_args': { + 'indices': ((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6), + (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6), + (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 8)}, + 'expected_dense_tensor': [ + [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]], + [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]], + [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]], + [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]}, + ) + def test_get_dense_tensor_multi_dim( + self, sparse_input_args, expected_dense_tensor): """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - expected_dense_tensor = [ - [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], - [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], - ] + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) dense_tensor, _ = numeric_column._get_sequence_dense_tensor( @@ -1070,43 +1397,56 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) - def test_sequence_length(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '2D_with_shape', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 1], + 'shape': (2,)}, + {'testcase_name': '3D_with_shape', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (2,)}, + ) + def test_sequence_length(self, inputs_args, expected_sequence_length, shape): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=shape) _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _LazyBuilder({'aaa': inputs})) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) self.assertAllEqual(expected_sequence_length, sequence_length) self.assertEqual(np.int64, sequence_length.dtype) - def test_sequence_length_with_shape(self): - """Tests _sequence_length with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa') - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - def test_sequence_length_with_empty_rows(self): """Tests _sequence_length when some examples do not have ids.""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index b1820c10c8a83cd73143931ba4a1cb210851d86a..9b0b9b1e1bf51db9332806097c2b3ae14d0587ad 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -186,7 +186,7 @@ class WithShapeTest(test.TestCase): unexpected_shapes) def test_with_shape_2x2_with_partial_expected_shape(self): - with self.test_session(): + with self.cached_session(): value = [[42, 43], [44, 45]] actual_shape = [2, 2] tensor = constant_op.constant(value, shape=actual_shape) diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index f9b0efd1daaee42be1043b100edeb327d253d6f8..c223df5b6e944a19aa949b726e89daa9f0cb6cc8 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -192,7 +192,7 @@ class GlobalStepTest(test.TestCase): def test_invalid_dtype(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( 0.0, trainable=False, dtype=dtypes.float32, @@ -205,7 +205,7 @@ class GlobalStepTest(test.TestCase): def test_invalid_shape(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( [0], trainable=False, dtype=dtypes.int32, @@ -229,7 +229,7 @@ class GlobalStepTest(test.TestCase): def test_get_global_step(self): with ops.Graph().as_default() as g: self.assertEquals(None, variables_lib2.get_global_step()) - variables_lib.Variable( + variables_lib.VariableV1( 0, trainable=False, dtype=dtypes.int32, @@ -607,10 +607,10 @@ class ModelVariablesTest(test.TestCase): with self.cached_session(): with variable_scope.variable_scope('A'): variables_lib2.local_variable([5]) - a = variables_lib.Variable([5]) + a = variables_lib.VariableV1([5]) with variable_scope.variable_scope('B'): variables_lib2.local_variable([5]) - b = variables_lib.Variable([5]) + b = variables_lib.VariableV1([5]) self.assertEquals([a], variables_lib2.get_trainable_variables('A')) self.assertEquals([b], variables_lib2.get_trainable_variables('B')) @@ -953,7 +953,7 @@ class AssignFromCheckpointTest(test.TestCase): # Create a set of variables to save in the checkpoint. for var_name in var_names_to_values: var_value = var_names_to_values[var_name] - var_list.append(variables_lib.Variable(var_value, name=var_name)) + var_list.append(variables_lib.VariableV1(var_value, name=var_name)) saver = saver_lib.Saver(var_list) init_op = variables_lib.variables_initializer(var_list) sess.run(init_op) @@ -1106,7 +1106,7 @@ class AssignFromCheckpointFnTest(test.TestCase): # Create a set of variables to save in the checkpoint. for var_name in var_names_to_values: var_value = var_names_to_values[var_name] - var_list.append(variables_lib.Variable(var_value, name=var_name)) + var_list.append(variables_lib.VariableV1(var_value, name=var_name)) saver = saver_lib.Saver(var_list) init_op = variables_lib.variables_initializer(var_list) sess.run(init_op) @@ -1297,7 +1297,7 @@ class AssignFromCheckpointFnTest(test.TestCase): class ZeroInitializerOpTest(test.TestCase): def _testZeroInitializer(self, shape, initializer, use_init): - var = variables_lib.Variable(initializer) + var = variables_lib.VariableV1(initializer) var_zero = variables_lib2.zero_initializer(var) with self.cached_session() as sess: with self.assertRaisesOpError('Attempting to use uninitialized value'): @@ -1350,12 +1350,12 @@ class FilterVariablesTest(test.TestCase): g = ops.Graph() with g.as_default(): var_list = [] - var_list.append(variables_lib.Variable(0, name='conv1/weights')) - var_list.append(variables_lib.Variable(0, name='conv1/biases')) - var_list.append(variables_lib.Variable(0, name='conv2/weights')) - var_list.append(variables_lib.Variable(0, name='conv2/biases')) - var_list.append(variables_lib.Variable(0, name='clfs/weights')) - var_list.append(variables_lib.Variable(0, name='clfs/biases')) + var_list.append(variables_lib.VariableV1(0, name='conv1/weights')) + var_list.append(variables_lib.VariableV1(0, name='conv1/biases')) + var_list.append(variables_lib.VariableV1(0, name='conv2/weights')) + var_list.append(variables_lib.VariableV1(0, name='conv2/biases')) + var_list.append(variables_lib.VariableV1(0, name='clfs/weights')) + var_list.append(variables_lib.VariableV1(0, name='clfs/biases')) self._var_list = var_list def _test_filter_variables(self, diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 0f0813c07f8bd330b089780064e02f8dfe7d49f6..57a5bfbf43c915775c6b0ef05baac19581213a09 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -17,11 +17,14 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_kernel_library", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_custom_op_py_library( @@ -109,13 +112,13 @@ tf_gen_op_wrapper_py( deps = [":fused_conv2d_bias_activation_op_op_lib"], ) -cuda_py_test( - name = "fused_conv2d_bias_activation_op_test", - size = "large", - srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"], - additional_deps = [ +py_library( + name = "fused_conv2d_bias_activation_op_test_base", + testonly = 1, + srcs = ["python/ops/fused_conv2d_bias_activation_op_test_base.py"], + visibility = ["//tensorflow/compiler/tf2xla:internal"], + deps = [ ":fused_conv_py", - "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -128,16 +131,28 @@ cuda_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "fused_conv2d_bias_activation_op_test", + size = "large", + srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"], + additional_deps = [ + ":fused_conv2d_bias_activation_op_test_base", + "//tensorflow/python:client_testlib", ], tags = [ - "manual", - "requires_cudnn6", + "manual", # TODO(b/117128481): re-enable after fixing OSS build + "no_pip", + "requires-gpu-sm70", ], ) cuda_py_test( name = "fused_conv2d_bias_activation_benchmark", - size = "large", srcs = ["python/ops/fused_conv2d_bias_activation_benchmark.py"], additional_deps = [ ":fused_conv_py", @@ -155,7 +170,7 @@ cuda_py_test( ], main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ - "manual", - "requires_cudnn6", + "manual", # TODO(b/117128481): re-enable after fixing OSS build + "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 716bb87e3883d682f840af73d0eb3b013c411348..93b1aaa85e88e00c1b12a388321a4d6fb10f1611 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -111,8 +111,8 @@ class FusedConv2DBiasActivationOp : public OpKernel { context, (GetTensorDim(strides, data_format_, 'N') == 1 && GetTensorDim(strides, data_format_, 'C') == 1), - errors::InvalidArgument("Convolutional strides are not supported in " - "the batch or depth dimensions.")); + errors::Unimplemented("Convolutional strides are not supported in " + "the batch and depth dimensions.")); // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. constexpr bool is_int8x4 = std::is_same::value; @@ -497,7 +497,8 @@ void LaunchFusedConv2DBiasActivationOp:: FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO), &maybe_transformed_filter)); functor::TransformFilter()( - ctx->eigen_device(), To32Bit(filter_param.tensor()), + ctx->eigen_device(), FORMAT_OIHW, + To32Bit(filter_param.tensor()), To32Bit(maybe_transformed_filter.tensor())); filter = &maybe_transformed_filter; } diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 0185ef662c2ed05b1ceaf0e3e8071bad4c0d1a0a..e5c8a34fc14b01d6f6c9bdca065e96332ed87556 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -12,898 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functional tests for fused conv2d bias and activation operation.""" + +"""Tests for fused convolutions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops +from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op_test_base as test_base from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging - - -def GetShrunkInceptionShapes(shrink=10): - """Iterator for smaller versions of convolution shapes in 2015 Inception. - - Relative to inception, each depth value is `depth // shrink`. - - Args: - shrink: Factor to shrink each depth value by relative to Inception. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the convolution - parameters of Inception layers. - """ - input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [ - 4, 8, 8, 2048 - ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [ - 4, 8, 8, 1760 - ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [ - 4, 17, 17, 192 - ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [ - 4, 17, 17, 192 - ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [ - 4, 17, 17, 192 - ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [ - 4, 17, 17, 160 - ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024], - [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [ - 4, 17, 17, 768 - ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768], - [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [ - 4, 35, 35, 64 - ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [ - 4, 35, 35, 256 - ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [ - 4, 35, 35, 192 - ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]] - filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [ - 1, 1, 2048, 192 - ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384], - [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [ - 1, 1, 1760, 320 - ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [ - 3, 3, 128, 320 - ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [ - 1, 3, 192, 256 - ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [ - 3, 3, 192, 224 - ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [ - 3, 1, 192, 192 - ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [ - 1, 3, 128, 192 - ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [ - 3, 1, 128, 128 - ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [ - 1, 1, 768, 128 - ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [ - 3, 3, 64, 96 - ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64], - [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [ - 1, 1, 192, 64 - ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64, - 64], [1, 1, 24, 64]] - out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [ - 4, 8, 8, 384 - ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [ - 4, 8, 8, 192 - ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [ - 4, 17, 17, 192 - ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [ - 4, 17, 17, 256 - ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [ - 4, 17, 17, 192 - ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [ - 4, 17, 17, 160 - ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [ - 4, 17, 17, 256 - ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [ - 4, 17, 17, 128 - ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [ - 4, 35, 35, 64 - ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96], - [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48], - [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]] - strides = [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1 - ] - # Shrink sizes to make the test faster - for i in input_sizes: - i[3] //= shrink - for f in filter_sizes: - f[2] //= shrink - f[3] //= shrink - for o in out_sizes: - o[3] //= shrink - # pylint: disable=invalid-name - VALID = "VALID" - SAME = "SAME" - # pylint: enable=invalid-name - paddings = [ - SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME, - SAME, SAME, SAME, SAME, VALID, VALID, VALID - ] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - - -def GetTestConfigs(): - """Get all the valid tests configs to run. - - Returns: - all the valid test configs as tuples of data_format and use_gpu. - """ - test_configs = [("NCHW", True), ("NHWC", True)] - return test_configs - - -class FusedConv2DBiasActivationTest(test.TestCase): - - def _DtypesToTest(self, use_gpu): - return [dtypes.float32] - - def _FilterFormatsToTest(self, use_gpu): - return ["HWIO", "OIHW"] - - def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias, - strides, padding, activation_mode, data_format, - filter_format, dtype): - """Verifies the output values of the convolution function. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [kernel_rows, kernel_cols, input_depth, output_depth]. - bias: 1-D bias tensor of length output_depth. - strides: Stride: [col_stride, row_stride] - padding: Padding type. - activation_mode: Activation mode. - data_format: Format of the data tensors. - filter_format: Filter format to use for the fused convolution. - dtype: Data type for inputs and outputs. - Returns: - Symbolic tensor value and reference value that can be used to - execute the computation and verify the results. - """ - input_size = np.prod(tensor_in_sizes) - filter_size = np.prod(filter_in_sizes) - bias_size = filter_in_sizes[-1] # equals to output depth - # Initializes the input tensor with array containing incrementing - # numbers from 1. - x1 = [f * 1.0 for f in range(1, input_size + 1)] - x2 = [f * 1.0 for f in range(1, filter_size + 1)] - # This is to guarantee that there is always negative values after - # bias add so that we can test whether relu works correctly. - x3 = bias - with self.test_session(use_gpu=True): - t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) - t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) - fused_t2 = t2 - if filter_format == "OIHW": - fused_t2 = HwioToOihw(t2) - t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) - strides = [1] + strides + [1] - if data_format == "NCHW": - t1 = test_util.NHWCToNCHW(t1) - strides = test_util.NHWCToNCHW(strides) - output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - t1, - fused_t2, - t3, - strides=strides, - padding=padding, - data_format=data_format, - filter_format=filter_format, - activation_mode=activation_mode) - ref_conv_output = nn_ops.conv2d( - t1, t2, strides=strides, padding=padding, data_format=data_format) - ref_bias_output = nn_ops.bias_add( - ref_conv_output, t3, data_format=data_format) - ref_output = nn_ops.relu(ref_bias_output) - if data_format == "NCHW": - output = test_util.NCHWToNHWC(output) - ref_output = test_util.NCHWToNHWC(ref_output) - - return output, ref_output - - def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides, - padding): - """Verifies that CPU and GPU produce the same values. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [kernel_rows, kernel_cols, input_depth, output_depth]. - conv_strides: [row_stride, col_stride] for the convolution; - padding: Padding type. - """ - x1 = np.random.rand(*tensor_in_sizes).astype(np.float32) - x2 = np.random.rand(*filter_in_sizes).astype(np.float32) - x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) - - def _SetupVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) - strides = [1] + conv_strides + [1] - if data_format == "NCHW": - t1 = test_util.NHWCToNCHW(t1) - strides = test_util.NHWCToNCHW(strides) - output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - t1, - t2, - t3, - strides=strides, - padding=padding, - data_format=data_format, - activation_mode="Relu") - - if data_format == "NCHW": - output = test_util.NCHWToNHWC(output) - return output - - tensors = [] - for (data_format, use_gpu) in GetTestConfigs(): - tensors.append(_SetupVal(data_format, use_gpu)) - with self.test_session() as sess: - values = sess.run(tensors) - for i in range(1, len(values)): - self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5) - - def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides, - padding): - tensors = [] - ref_tensors = [] - for (data_format, use_gpu) in GetTestConfigs(): - for dtype in self._DtypesToTest(use_gpu): - for filter_format in self._FilterFormatsToTest(use_gpu): - result, expected = self._SetupValuesForDevice( - tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", - data_format, filter_format, dtype) - tensors.append(result) - ref_tensors.append(expected) - with self.test_session() as sess: - values = sess.run(tensors) - ref_values = sess.run(ref_tensors) - for i in range(len(tensors)): - conv = tensors[i] - value = values[i] - ref_value = ref_values[i] - tf_logging.info("expected = ", ref_value) - tf_logging.info("actual = ", value) - tol = 1e-5 - if value.dtype == np.float16: - tol = 1e-3 - self.assertAllClose( - np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol) - self.assertShapeEqual(value, conv) - - def testConv2D1x1Filter(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D1x1Filter test.") - return - # expected_output = [ - # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0, - # 86.0, 43.0, 165.0, 131.0, 97.0 - # ] - medians = [-45.0, -130.0, -215.0] - self._VerifyValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[1, 1, 3, 3], - bias=medians, - strides=[1, 1], - padding="VALID") - - def testConv2DEmpty(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2DEmpty test.") - return - # expected_output = [] - self._VerifyValues( - tensor_in_sizes=[0, 2, 3, 3], - filter_in_sizes=[1, 1, 3, 3], - bias=[0.0, 0.0, 0.0], - strides=[1, 1], - padding="VALID") - - def testConv2D2x2Filter(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D2x2Filter test.") - return - # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0] - self._VerifyValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[2, 2, 3, 3], - bias=[-2500.0, -2500.0, -2500.0], - strides=[1, 1], - padding="VALID") - - def testConv2D1x2Filter(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D1x2Filter test.") - return - # expected_output = [ - # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0 - # ] - self._VerifyValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[1, 2, 3, 3], - bias=[-500.0, -500.0, -500.0], - strides=[1, 1], - padding="VALID") - - def testConv2D2x2FilterStride2(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D2x2FilterStride2 test.") - return - # expected_output = [0.0, 67.0, 163.0] - self._VerifyValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[2, 2, 3, 3], - bias=[-2300.0, -2300.0, -2300.0], - strides=[2, 2], - padding="VALID") - - def testConv2D2x2FilterStride2Same(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.") - return - # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] - self._VerifyValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[2, 2, 3, 3], - bias=[-2300.0, -1000.0, -1000.0], - strides=[2, 2], - padding="SAME") - - def testConv2D2x2FilterStride1x2(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.") - return - # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0] - self._VerifyValues( - tensor_in_sizes=[1, 3, 6, 1], - filter_in_sizes=[2, 2, 1, 1], - bias=[-90.0], - strides=[1, 2], - padding="VALID") - - def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.") - return - # expected_output = [0, 0, 175, 205] - self._VerifyValues( - tensor_in_sizes=[1, 7, 7, 1], - filter_in_sizes=[2, 2, 1, 1], - bias=[-100.0], - strides=[3, 3], - padding="VALID") - - def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.") - return - # expected = [0, 0, 2, 4] - self._VerifyValues( - tensor_in_sizes=[1, 3, 3, 1], - filter_in_sizes=[1, 1, 1, 1], - bias=[-5.0], - strides=[2, 2], - padding="SAME") - - # expected = [0, 0, 4, 6] - self._VerifyValues( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[1, 1, 1, 1], - bias=[-5.0], - strides=[2, 2], - padding="SAME") - - # expected = [4, 0, 1, 0] - self._VerifyValues( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - bias=[-40.0], - strides=[3, 3], - padding="SAME") - - def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.") - return - # expected = [0, 5] - self._VerifyValues( - tensor_in_sizes=[1, 2, 2, 1], - filter_in_sizes=[2, 2, 1, 2], - bias=[-50.0, -55.0], - strides=[1, 1], - padding="VALID") - - # expected = [0, 2, 282, 322] - self._VerifyValues( - tensor_in_sizes=[1, 8, 8, 1], - filter_in_sizes=[2, 2, 1, 1], - bias=[-200.0], - strides=[4, 4], - padding="SAME") - - def testShapeFunctionEdgeCases(self): - # All shapes unknown. - c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 1], - padding="SAME", - activation_mode="Relu") - self.assertEqual([None, None, None, None], c1.get_shape().as_list()) - - # Incorrect input shape. - with self.assertRaises(ValueError): - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32, shape=[1, 3]), - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 1], - padding="SAME", - activation_mode="Relu") - - # Incorrect filter shape. - with self.assertRaises(ValueError): - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32, shape=[1, 3]), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 1], - padding="SAME", - activation_mode="Relu") - - # Depth mismatch. - with self.assertRaises(ValueError): - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), - array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 1], - padding="SAME", - activation_mode="Relu") - - def testOpEdgeCases(self, gpu_only=True): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping OpEdgeCases tests.") - return - with self.test_session() as sess: - # Illegal strides. - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Convolutional strides are not supported in " - "the batch or depth dimensions."): - sess.run( - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - strides=[2, 1, 1, 1], - padding="SAME", - activation_mode="Relu")) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Convolutional strides are not supported in " - "the batch or depth dimensions."): - sess.run( - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 2], - padding="SAME", - activation_mode="Relu")) - - # Illegal activation mode. - with self.assertRaisesRegexp(ValueError, - "Op passed string 'Tanh' not in:"): - sess.run( - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.float32), - strides=[1, 1, 1, 1], - padding="SAME", - activation_mode="Tanh")) - - # Filter larger than input. - with self.assertRaisesRegexp(ValueError, "Negative dimension size"): - sess.run( - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), - array_ops.placeholder(dtypes.float32, shape=[20, 21, 3, 2]), - array_ops.placeholder(dtypes.float32, shape=[2]), - strides=[1, 1, 1, 1], - padding="VALID", - activation_mode="Relu")) - with self.assertRaisesRegexp(ValueError, "Negative dimension size"): - sess.run( - fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), - array_ops.placeholder(dtypes.float32, shape=[21, 20, 3, 2]), - array_ops.placeholder(dtypes.float32, shape=[2]), - strides=[1, 1, 1, 1], - padding="VALID", - activation_mode="Relu")) - - -def GetInceptionFwdTest(input_size, filter_size, stride, padding, - gpu_only=True): - - def Test(self): - if gpu_only and not test.is_gpu_available(): - tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size, - stride, padding)) - return - tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride, - padding)) - self._CompareFwdValues(input_size, filter_size, [stride, stride], padding) - - return Test - - -def CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type): - """Calculates the size of an output dimension of a strided convolution. - - Given the sizes of the corresponding dimension of the input and filter shapes, - and the stride and padding_types, calculates the size of the output dimension. - This function can be called separately for each input dimension. - - Args: - input_dim: An `int` specifying the size of the input dimension. - filter_dim: An `int` specifying the size of the filter dimension. - stride: An `int` specifying the step size of the convolution along the - input dimension. - padding_type: either 'VALID' or 'SAME'. - - Returns: - The size of the output dimension. - """ - if padding_type == "VALID": - return (input_dim - filter_dim + stride) // stride - else: # padding_type == 'SAME' - return (input_dim + stride - 1) // stride - - -def NchwVectCToNchw(in_tensor): - # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] - t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) - n = in_tensor.shape.dims[0].value - c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value - h = in_tensor.shape.dims[2].value - w = in_tensor.shape.dims[3].value - return array_ops.reshape(t, [n, c, h, w]) - - -def OihwVectIToHwio(in_tensor): - # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] - t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) - o = in_tensor.shape.dims[0].value - i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value - h = in_tensor.shape.dims[2].value - w = in_tensor.shape.dims[3].value - return array_ops.reshape(t, [h, w, i, o]) - - -def NchwToNchwVectC(in_tensor): - n, c, h, w = in_tensor.shape.as_list() - assert c % 4 == 0 - t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) - return array_ops.transpose(t, [0, 1, 3, 4, 2]) - - -def HwioToOihw(in_tensor): - return array_ops.transpose(in_tensor, [3, 2, 0, 1]) - - -def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, - padding, strides, side_input_scale, - side_input, biases, apply_relu): - """Simulates the int8 fused 2-D convolution op using separate float ops. - - The arguments and return values have the same format, meanings and - restrictions as the actual op. - Args: - conv_input_scale: A scalar 'float'. - conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. - kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout. - padding: A `string` from: `"SAME", "VALID"`. - strides: A list of `ints`. - side_input_scale: A scalar 'float'. - side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. - biases: A `Tensor` of type `float32` in NCHW layout. - apply_relu: A boolean to specify whether to apply "Relu" activation function - that clips outputs to the range [0, 127], or "None" activation that clips - to the range [-128, 127]. - Returns: - A `Tensor` of type `qint8` in NCHW_VECT_C layout. - """ - conv_result = nn_ops.conv2d( - NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)), - OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)), - strides=strides, - padding=padding, - data_format="NCHW") * conv_input_scale - - conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw( - gen_array_ops.dequantize(side_input, -128, 127)) - - output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") - if apply_relu: - output = nn_ops.relu(output) - - result, _, _ = gen_array_ops.quantize_v2( - NchwToNchwVectC(output), -128, 127, dtypes.qint8) - return result - - -class FusedConvInt8Tests(test.TestCase): - _test_params = [ - { - "batch_size": 1, - "input_channels": 4, - "output_channels": 4, - "input_height": 8, - "input_width": 8, - "filter_height": 6, - "filter_width": 6, - "vertical_stride": 2, - "horizontal_stride": 2, - "conv_input_scale": 0.002, - "side_input_scale": 0.0, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 1, - "input_channels": 4, - "output_channels": 4, - "input_height": 6, - "input_width": 6, - "filter_height": 6, - "filter_width": 6, - "vertical_stride": 2, - "horizontal_stride": 2, - "conv_input_scale": 0.002, - "side_input_scale": 0.0, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 2, - "input_channels": 8, - "output_channels": 16, - "input_height": 8, - "input_width": 8, - "filter_height": 3, - "filter_width": 3, - "vertical_stride": 2, - "horizontal_stride": 2, - "conv_input_scale": 0.002, - "side_input_scale": 0.0, - "bias_scale": 1, - "padding_type": "VALID" - }, - { - "batch_size": 2, - "input_channels": 8, - "output_channels": 16, - "input_height": 8, - "input_width": 8, - "filter_height": 3, - "filter_width": 3, - "vertical_stride": 2, - "horizontal_stride": 2, - "conv_input_scale": 0.002, - "side_input_scale": 0.0, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 2, - "input_channels": 8, - "output_channels": 16, - "input_height": 8, - "input_width": 8, - "filter_height": 3, - "filter_width": 3, - "vertical_stride": 2, - "horizontal_stride": 2, - "conv_input_scale": 0.002, - "side_input_scale": 0.5, - "bias_scale": 1, - "padding_type": "VALID" - }, - { - "batch_size": 2, - "input_channels": 16, - "output_channels": 16, - "input_height": 9, - "input_width": 9, - "filter_height": 3, - "filter_width": 3, - "vertical_stride": 1, - "horizontal_stride": 1, - "conv_input_scale": 0.001, - "side_input_scale": 0.5, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 3, - "input_channels": 8, - "output_channels": 8, - "input_height": 9, - "input_width": 9, - "filter_height": 5, - "filter_width": 5, - "vertical_stride": 1, - "horizontal_stride": 1, - "conv_input_scale": 0.001, - "side_input_scale": 0.5, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 3, - "input_channels": 8, - "output_channels": 8, - "input_height": 9, - "input_width": 9, - "filter_height": 7, - "filter_width": 1, - "vertical_stride": 2, - "horizontal_stride": 1, - "conv_input_scale": 0.002, - "side_input_scale": 0.5, - "bias_scale": 1, - "padding_type": "SAME" - }, - { - "batch_size": 3, - "input_channels": 8, - "output_channels": 8, - "input_height": 9, - "input_width": 9, - "filter_height": 1, - "filter_width": 7, - "vertical_stride": 1, - "horizontal_stride": 1, - "conv_input_scale": 0.002, - "side_input_scale": 0.5, - "bias_scale": 1, - "padding_type": "SAME" - }, - ] - - def runTest(self, test_param, apply_relu): - batch_size = test_param["batch_size"] - input_channels = test_param["input_channels"] - output_channels = test_param["output_channels"] - input_height = test_param["input_height"] - input_width = test_param["input_width"] - filter_height = test_param["filter_height"] - filter_width = test_param["filter_width"] - vertical_stride = test_param["vertical_stride"] - horizontal_stride = test_param["horizontal_stride"] - conv_input_scale = test_param["conv_input_scale"] - side_input_scale = test_param["side_input_scale"] - bias_scale = test_param["bias_scale"] - padding_type = test_param["padding_type"] - - conv_input, _, _ = gen_array_ops.quantize_v2( - random_ops.random_uniform( - [batch_size, input_channels // 4, input_height, input_width, 4], - minval=-0.0, - maxval=1.0, - dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - - kernel, _, _ = gen_array_ops.quantize_v2( - random_ops.random_uniform( - [ - output_channels, input_channels // 4, filter_height, - filter_width, 4 - ], - minval=-1.0, - maxval=1.0, - dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - - output_height = CalculateConvolvedOutputDim(input_height, filter_height, - vertical_stride, padding_type) - output_width = CalculateConvolvedOutputDim(input_width, filter_width, - horizontal_stride, padding_type) - tf_logging.info("output_height=", output_height, ", output_width=", - output_width) - - side_input, _, _ = gen_array_ops.quantize_v2( - random_ops.random_uniform( - [batch_size, output_channels // 4, output_height, output_width, 4], - minval=0.0, - maxval=1.0, - dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - - biases = random_ops.random_uniform( - [output_channels], - minval=-10 * bias_scale, - maxval=20 * bias_scale, - dtype=dtypes.float32) - - strides = [1, 1, vertical_stride, horizontal_stride] - - actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - conv_input, - kernel, - biases, - strides=strides, - padding=padding_type, - conv_input_scale=conv_input_scale, - side_input_scale=side_input_scale, - side_input=side_input, - activation_mode="Relu" if apply_relu else "None", - data_format="NCHW_VECT_C", - filter_format="OIHW_VECT_I") - expected = SimulateFusedConv2dBiasActivationInt8( - conv_input_scale, conv_input, kernel, padding_type, strides, - side_input_scale, side_input, biases, apply_relu) - with self.test_session(use_gpu=True) as sess: - actual_y, expected_y = sess.run([actual, expected]) - tf_logging.info("actual_y = ", actual_y) - tf_logging.info("expected_y = ", expected_y) - self.assertTrue(np.array_equal(actual_y, expected_y)) +# Instantiate the two test suites from test_base, mixing in test.TestCase as +# the test framework. +class FusedConv2DBiasActivationTest(test_base.FusedConv2DBiasActivationTest, + test.TestCase): + pass - def testFusedConvInt8(self): - if not test.is_gpu_available( - cuda_only=True, min_cuda_compute_capability=(6, 1)): - tf_logging.info("int8 test skipped because not run with --config=cuda or " - "no GPUs with compute capability >= 6.1 are available.") - return - for apply_relu in [True, False]: - for test_param in self._test_params: - self.runTest(test_param, apply_relu) +class FusedConvInt8Tests(test_base.FusedConvInt8Tests, test.TestCase): + pass -if __name__ == "__main__": - for index, (input_size_, filter_size_, output_size_, stride_, - padding_) in enumerate(GetShrunkInceptionShapes()): - setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index), - GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_)) - # TODO(b/35359731) - # Fwd, BckInput, and BackFilter to test that for certain input parameter - # set, winograd nonfused algorithm will be excluded from conv autotune. If - # in such case, winograd nonfused algorithm is added as one option of the - # conv autotune, and cuDNN version is smaller than 7, the following tests - # will fail. - ishape = [1, 400, 400, 1] - fshape = [1, 1, 1, 256] - oshape = [1, 400, 400, 256] - setattr(FusedConv2DBiasActivationTest, - "testInceptionFwd_No_Winograd_Nonfused", - GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)) +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..35fc65e4ba8ff5f38f0024930213468b3dc0bed6 --- /dev/null +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py @@ -0,0 +1,945 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Provides test suites that can be run to test fused convolutions. + +Each of the two test suites in this module, FusedConv2DBiasActivationTest and +FusedConvInt8Tests, should be "instantiated" by declaring a class which inherits +from the FusedConv test and a class that provides the standard test.TestCase +API. + +See e.g. fused_conv2d_bias_activation_op_test.py in this folder. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import numpy as np + +from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def _GetShrunkInceptionShapes(shrink=10): + """Iterator for smaller versions of convolution shapes in 2015 Inception. + + Relative to inception, each depth value is `depth // shrink`. + + Args: + shrink: Factor to shrink each depth value by relative to Inception. + + Yields: + Tuple (input_size, filter_size, out_size, stride, padding), the convolution + parameters of Inception layers. + """ + input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [ + 4, 8, 8, 2048 + ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [ + 4, 8, 8, 1760 + ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [ + 4, 17, 17, 192 + ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [ + 4, 17, 17, 192 + ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [ + 4, 17, 17, 192 + ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [ + 4, 17, 17, 160 + ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024], + [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [ + 4, 17, 17, 768 + ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768], + [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [ + 4, 35, 35, 64 + ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [ + 4, 35, 35, 256 + ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [ + 4, 35, 35, 192 + ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]] + filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [ + 1, 1, 2048, 192 + ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384], + [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [ + 1, 1, 1760, 320 + ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [ + 3, 3, 128, 320 + ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [ + 1, 3, 192, 256 + ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [ + 3, 3, 192, 224 + ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [ + 3, 1, 192, 192 + ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [ + 1, 3, 128, 192 + ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [ + 3, 1, 128, 128 + ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [ + 1, 1, 768, 128 + ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [ + 3, 3, 64, 96 + ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64], + [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [ + 1, 1, 192, 64 + ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64, + 64], [1, 1, 24, 64]] + out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [ + 4, 8, 8, 384 + ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [ + 4, 8, 8, 192 + ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [ + 4, 17, 17, 192 + ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [ + 4, 17, 17, 256 + ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [ + 4, 17, 17, 192 + ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [ + 4, 17, 17, 160 + ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [ + 4, 17, 17, 256 + ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [ + 4, 17, 17, 128 + ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [ + 4, 35, 35, 64 + ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96], + [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48], + [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]] + strides = [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1 + ] + # Shrink sizes to make the test faster + for i in input_sizes: + i[3] //= shrink + for f in filter_sizes: + f[2] //= shrink + f[3] //= shrink + for o in out_sizes: + o[3] //= shrink + # pylint: disable=invalid-name + VALID = "VALID" + SAME = "SAME" + # pylint: enable=invalid-name + paddings = [ + SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, + VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, + SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, + SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME, + SAME, SAME, SAME, SAME, VALID, VALID, VALID + ] + for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, + paddings): + yield i, f, o, s, p + + +def _GetTestConfigs(): + """Get all the valid tests configs to run. + + Returns: + all the valid test configs as tuples of data_format and use_gpu. + """ + test_configs = [("NCHW", True), ("NHWC", True)] + return test_configs + + +def _IotaNdF32Constant(dim_sizes): + + def MakeList(dims): + if len(dims) == 1: + return [float(1 + f) for f in range(dims[0])] + return [MakeList(dims[1:]) for _ in range(dims[0])] + + return constant_op.constant(MakeList(dim_sizes), dtype=dtypes.float32) + + +def _GetInceptionFwdTest(input_size, + filter_size, + stride, + padding, + gpu_only=True): + + def Test(self): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping InceptionFwd %s", + (input_size, filter_size, stride, padding)) + return + tf_logging.info("Testing InceptionFwd %s", + (input_size, filter_size, stride, padding)) + self.CompareFwdValues(input_size, filter_size, [stride, stride], padding) + + return Test + + +class FusedConv2DBiasActivationTest(object): + + @contextlib.contextmanager + def test_scope(self): # pylint: disable=invalid-name + """Can be overridden in base classes to provide a test scope.""" + yield + + def _DtypesToTest(self, use_gpu): + return [dtypes.float32] + + def _FilterFormatsToTest(self, use_gpu): + return ["HWIO", "OIHW"] + + def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias, + strides, padding, activation_mode, data_format, + filter_format, dtype): + """Verifies the output values of the convolution function. + + Args: + tensor_in_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_in_sizes: Filter tensor dimensions in + [kernel_rows, kernel_cols, input_depth, output_depth]. + bias: 1-D bias tensor of length output_depth. + strides: Stride: [col_stride, row_stride] + padding: Padding type. + activation_mode: Activation mode. + data_format: Format of the data tensors. + filter_format: Filter format to use for the fused convolution. + dtype: Data type for inputs and outputs. + Returns: + Symbolic tensor value and reference value that can be used to + execute the computation and verify the results. + """ + input_size = np.prod(tensor_in_sizes) + filter_size = np.prod(filter_in_sizes) + bias_size = filter_in_sizes[-1] # equals to output depth + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, input_size + 1)] + x2 = [f * 1.0 for f in range(1, filter_size + 1)] + # This is to guarantee that there are always negative values after + # bias add so that we can test whether relu works correctly. + x3 = bias + with self.cached_session(use_gpu=True), self.test_scope(): + t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) + t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) + fused_t2 = t2 + if filter_format == "OIHW": + fused_t2 = _HwioToOihw(t2) + t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) + strides = [1] + strides + [1] + if data_format == "NCHW": + t1 = test_util.NHWCToNCHW(t1) + strides = test_util.NHWCToNCHW(strides) + output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + t1, + fused_t2, + t3, + strides=strides, + padding=padding, + data_format=data_format, + filter_format=filter_format, + activation_mode=activation_mode) + ref_conv_output = nn_ops.conv2d( + t1, t2, strides=strides, padding=padding, data_format=data_format) + ref_bias_output = nn_ops.bias_add( + ref_conv_output, t3, data_format=data_format) + ref_output = nn_ops.relu(ref_bias_output) + if data_format == "NCHW": + output = test_util.NCHWToNHWC(output) + ref_output = test_util.NCHWToNHWC(ref_output) + + return output, ref_output + + def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides, + padding): + """Verifies that CPU and GPU produce the same values. + + Args: + tensor_in_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_in_sizes: Filter tensor dimensions in + [kernel_rows, kernel_cols, input_depth, output_depth]. + conv_strides: [row_stride, col_stride] for the convolution; + padding: Padding type. + """ + x1 = np.random.rand(*tensor_in_sizes).astype(np.float32) + x2 = np.random.rand(*filter_in_sizes).astype(np.float32) + x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) + + def _SetupVal(data_format, use_gpu): + with self.cached_session(use_gpu=use_gpu), self.test_scope(): + t1 = constant_op.constant(x1, shape=tensor_in_sizes) + t2 = constant_op.constant(x2, shape=filter_in_sizes) + t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) + strides = [1] + conv_strides + [1] + if data_format == "NCHW": + t1 = test_util.NHWCToNCHW(t1) + strides = test_util.NHWCToNCHW(strides) + output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + t1, + t2, + t3, + strides=strides, + padding=padding, + data_format=data_format, + activation_mode="Relu") + + if data_format == "NCHW": + output = test_util.NCHWToNHWC(output) + return output + + tensors = [] + for (data_format, use_gpu) in _GetTestConfigs(): + tensors.append(_SetupVal(data_format, use_gpu)) + with self.cached_session() as sess, self.test_scope(): + values = sess.run(tensors) + for i in range(1, len(values)): + self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3) + + def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides, + padding): + tensors = [] + ref_tensors = [] + for (data_format, use_gpu) in _GetTestConfigs(): + for dtype in self._DtypesToTest(use_gpu): + for filter_format in self._FilterFormatsToTest(use_gpu): + result, expected = self._SetupValuesForDevice( + tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", + data_format, filter_format, dtype) + tensors.append(result) + ref_tensors.append(expected) + with self.cached_session() as sess, self.test_scope(): + values = sess.run(tensors) + ref_values = sess.run(ref_tensors) + for i in range(len(tensors)): + conv = tensors[i] + value = values[i] + ref_value = ref_values[i] + tf_logging.info("expected = %s", ref_value) + tf_logging.info("actual = %s", value) + tol = 1e-5 + if value.dtype == np.float16: + tol = 1e-3 + self.assertAllClose( + np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol) + self.assertShapeEqual(value, conv) + + def testConv2D1x1Filter(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D1x1Filter test.") + return + # expected_output = [ + # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0, + # 86.0, 43.0, 165.0, 131.0, 97.0 + # ] + medians = [-45.0, -130.0, -215.0] + self._VerifyValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + bias=medians, + strides=[1, 1], + padding="VALID") + + def testConv2DEmpty(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2DEmpty test.") + return + # expected_output = [] + self._VerifyValues( + tensor_in_sizes=[0, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + bias=[0.0, 0.0, 0.0], + strides=[1, 1], + padding="VALID") + + def testConv2D2x2Filter(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D2x2Filter test.") + return + # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0] + self._VerifyValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + bias=[-2500.0, -2500.0, -2500.0], + strides=[1, 1], + padding="VALID") + + def testConv2D1x2Filter(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D1x2Filter test.") + return + # expected_output = [ + # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0 + # ] + self._VerifyValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[1, 2, 3, 3], + bias=[-500.0, -500.0, -500.0], + strides=[1, 1], + padding="VALID") + + def testConv2D2x2FilterStride2(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D2x2FilterStride2 test.") + return + # expected_output = [0.0, 67.0, 163.0] + self._VerifyValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + bias=[-2300.0, -2300.0, -2300.0], + strides=[2, 2], + padding="VALID") + + def testConv2D2x2FilterStride2Same(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.") + return + # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] + self._VerifyValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + bias=[-2300.0, -1000.0, -1000.0], + strides=[2, 2], + padding="SAME") + + def testConv2D2x2FilterStride1x2(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.") + return + # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0] + self._VerifyValues( + tensor_in_sizes=[1, 3, 6, 1], + filter_in_sizes=[2, 2, 1, 1], + bias=[-90.0], + strides=[1, 2], + padding="VALID") + + def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.") + return + # expected_output = [0, 0, 175, 205] + self._VerifyValues( + tensor_in_sizes=[1, 7, 7, 1], + filter_in_sizes=[2, 2, 1, 1], + bias=[-100.0], + strides=[3, 3], + padding="VALID") + + def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.") + return + # expected = [0, 0, 2, 4] + self._VerifyValues( + tensor_in_sizes=[1, 3, 3, 1], + filter_in_sizes=[1, 1, 1, 1], + bias=[-5.0], + strides=[2, 2], + padding="SAME") + + # expected = [0, 0, 4, 6] + self._VerifyValues( + tensor_in_sizes=[1, 4, 4, 1], + filter_in_sizes=[1, 1, 1, 1], + bias=[-5.0], + strides=[2, 2], + padding="SAME") + + # expected = [4, 0, 1, 0] + self._VerifyValues( + tensor_in_sizes=[1, 4, 4, 1], + filter_in_sizes=[2, 2, 1, 1], + bias=[-40.0], + strides=[3, 3], + padding="SAME") + + def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.") + return + # expected = [0, 5] + self._VerifyValues( + tensor_in_sizes=[1, 2, 2, 1], + filter_in_sizes=[2, 2, 1, 2], + bias=[-50.0, -55.0], + strides=[1, 1], + padding="VALID") + + # expected = [0, 2, 282, 322] + self._VerifyValues( + tensor_in_sizes=[1, 8, 8, 1], + filter_in_sizes=[2, 2, 1, 1], + bias=[-200.0], + strides=[4, 4], + padding="SAME") + + def testShapeFunctionEdgeCases(self): + # All shapes unknown. + c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.float32), + strides=[1, 1, 1, 1], + padding="SAME", + activation_mode="Relu") + self.assertEqual([None, None, None, None], c1.get_shape().as_list()) + + # Incorrect input shape. + with self.assertRaises(ValueError): + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + array_ops.placeholder(dtypes.float32, shape=[1, 3]), + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.float32), + strides=[1, 1, 1, 1], + padding="SAME", + activation_mode="Relu") + + # Incorrect filter shape. + with self.assertRaises(ValueError): + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.float32, shape=[1, 3]), + array_ops.placeholder(dtypes.float32), + strides=[1, 1, 1, 1], + padding="SAME", + activation_mode="Relu") + + # Depth mismatch. + with self.assertRaises(ValueError): + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]), + array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]), + array_ops.placeholder(dtypes.float32), + strides=[1, 1, 1, 1], + padding="SAME", + activation_mode="Relu") + + def testOpEdgeCases(self, gpu_only=True): + if gpu_only and not test.is_gpu_available(): + tf_logging.info("Skipping OpEdgeCases tests.") + return + with self.cached_session() as sess, self.test_scope(): + # Illegal strides. + with self.assertRaisesRegexp( + errors_impl.UnimplementedError, + ".*strides.*in the batch and depth dimensions"): + sess.run( + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1]), + strides=[2, 1, 1, 1], + padding="SAME", + activation_mode="Relu")) + with self.assertRaisesRegexp( + errors_impl.UnimplementedError, + ".*strides.*in the batch and depth dimensions"): + sess.run( + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1]), + strides=[1, 1, 1, 2], + padding="SAME", + activation_mode="Relu")) + + # Illegal activation mode. + with self.assertRaisesRegexp(ValueError, + "Op passed string 'Tanh' not in:"): + sess.run( + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1, 1, 1, 1]), + _IotaNdF32Constant([1]), + strides=[1, 1, 1, 1], + padding="SAME", + activation_mode="Tanh")) + + # Filter larger than input. + with self.assertRaisesRegexp(ValueError, "Negative dimension size"): + sess.run( + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _IotaNdF32Constant([32, 20, 20, 3]), + _IotaNdF32Constant([20, 21, 3, 2]), + _IotaNdF32Constant([2]), + strides=[1, 1, 1, 1], + padding="VALID", + activation_mode="Relu")) + with self.assertRaisesRegexp(ValueError, "Negative dimension size"): + sess.run( + fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + _IotaNdF32Constant([32, 20, 20, 3]), + _IotaNdF32Constant([21, 20, 3, 2]), + _IotaNdF32Constant([2]), + strides=[1, 1, 1, 1], + padding="VALID", + activation_mode="Relu")) + + +# Add InceptionFwd tests to FusedConv2DBiasActivationTest. +for index, (input_size_, filter_size_, output_size_, stride_, + padding_) in enumerate(_GetShrunkInceptionShapes()): + setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index), + _GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_)) + +# TODO(b/35359731) +# Fwd, BckInput, and BackFilter to test that for certain input parameter +# set, winograd nonfused algorithm will be excluded from conv autotune. If +# in such case, winograd nonfused algorithm is added as one option of the +# conv autotune, and cuDNN version is smaller than 7, the following tests +# will fail. +ishape = [1, 400, 400, 1] +fshape = [1, 1, 1, 256] +oshape = [1, 400, 400, 256] +setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_No_Winograd_Nonfused", + _GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)) + + +def _CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type): + """Calculates the size of an output dimension of a strided convolution. + + Given the sizes of the corresponding dimension of the input and filter shapes, + and the stride and padding_types, calculates the size of the output dimension. + This function can be called separately for each input dimension. + + Args: + input_dim: An `int` specifying the size of the input dimension. + filter_dim: An `int` specifying the size of the filter dimension. + stride: An `int` specifying the step size of the convolution along the + input dimension. + padding_type: either 'VALID' or 'SAME'. + + Returns: + The size of the output dimension. + """ + if padding_type == "VALID": + return (input_dim - filter_dim + stride) // stride + else: # padding_type == 'SAME' + return (input_dim + stride - 1) // stride + + +def _NchwVectCToNchw(in_tensor): + # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] + t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) + n = in_tensor.shape.dims[0].value + c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [n, c, h, w]) + + +def _OihwVectIToHwio(in_tensor): + # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] + t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) + o = in_tensor.shape.dims[0].value + i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [h, w, i, o]) + + +def _NchwToNchwVectC(in_tensor): + n, c, h, w = in_tensor.shape.as_list() + assert c % 4 == 0 + t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) + return array_ops.transpose(t, [0, 1, 3, 4, 2]) + + +def _HwioToOihw(in_tensor): + return array_ops.transpose(in_tensor, [3, 2, 0, 1]) + + +def _SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, + padding, strides, side_input_scale, + side_input, biases, apply_relu): + """Simulates the int8 fused 2-D convolution op using separate float ops. + + The arguments and return values have the same format, meanings and + restrictions as the actual op. + Args: + conv_input_scale: A scalar 'float'. + conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout. + padding: A `string` from: `"SAME", "VALID"`. + strides: A list of `ints`. + side_input_scale: A scalar 'float'. + side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + biases: A `Tensor` of type `float32` in NCHW layout. + apply_relu: A boolean to specify whether to apply "Relu" activation function + that clips outputs to the range [0, 127], or "None" activation that clips + to the range [-128, 127]. + Returns: + A `Tensor` of type `qint8` in NCHW_VECT_C layout. + """ + conv_result = nn_ops.conv2d( + _NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)), + _OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)), + strides=strides, + padding=padding, + data_format="NCHW") * conv_input_scale + + conv_and_side_inputs = conv_result + side_input_scale * _NchwVectCToNchw( + gen_array_ops.dequantize(side_input, -128, 127)) + + output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + if apply_relu: + output = nn_ops.relu(output) + + result, _, _ = gen_array_ops.quantize_v2( + _NchwToNchwVectC(output), -128, 127, dtypes.qint8) + return result + + +# TODO(b/114580749): XLA:CPU/GPU don't support int8 at the moment, so this test +# doesn't currently use XLA. +class FusedConvInt8Tests(object): + _test_params = [ + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 8, + "input_width": 8, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 6, + "input_width": 6, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 16, + "output_channels": 16, + "input_height": 9, + "input_width": 9, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 5, + "filter_width": 5, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 7, + "filter_width": 1, + "vertical_stride": 2, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 1, + "filter_width": 7, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + ] + + @contextlib.contextmanager + def test_scope(self): # pylint: disable=invalid-name + """Can be overridden in base classes to provide a test scope.""" + yield + + def runTest(self, test_param, apply_relu): + batch_size = test_param["batch_size"] + input_channels = test_param["input_channels"] + output_channels = test_param["output_channels"] + input_height = test_param["input_height"] + input_width = test_param["input_width"] + filter_height = test_param["filter_height"] + filter_width = test_param["filter_width"] + vertical_stride = test_param["vertical_stride"] + horizontal_stride = test_param["horizontal_stride"] + conv_input_scale = test_param["conv_input_scale"] + side_input_scale = test_param["side_input_scale"] + bias_scale = test_param["bias_scale"] + padding_type = test_param["padding_type"] + + with self.cached_session(use_gpu=True) as sess, self.test_scope(): + conv_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, input_channels // 4, input_height, input_width, 4], + minval=-0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + kernel, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform([ + output_channels, input_channels // 4, filter_height, filter_width, + 4 + ], + minval=-1.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, + dtypes.qint8) + + output_height = _CalculateConvolvedOutputDim( + input_height, filter_height, vertical_stride, padding_type) + output_width = _CalculateConvolvedOutputDim( + input_width, filter_width, horizontal_stride, padding_type) + tf_logging.info("output_height=%s, output_width=%s", output_height, + output_width) + + side_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform([ + batch_size, output_channels // 4, output_height, output_width, 4 + ], + minval=0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, + dtypes.qint8) + + biases = random_ops.random_uniform([output_channels], + minval=-10 * bias_scale, + maxval=20 * bias_scale, + dtype=dtypes.float32) + + strides = [1, 1, vertical_stride, horizontal_stride] + + actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + conv_input, + kernel, + biases, + strides=strides, + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=side_input, + activation_mode="Relu" if apply_relu else "None", + data_format="NCHW_VECT_C", + filter_format="OIHW_VECT_I") + + expected = _SimulateFusedConv2dBiasActivationInt8( + conv_input_scale, conv_input, kernel, padding_type, strides, + side_input_scale, side_input, biases, apply_relu) + + actual_y, expected_y = sess.run([actual, expected]) + self.assertAllClose(actual_y, expected_y, rtol=0, atol=1) + + def testFusedConvInt8(self): + if not test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=(6, 1)): + tf_logging.info("int8 test skipped because not run with --config=cuda or " + "no GPUs with compute capability >= 6.1 are available.") + return + for apply_relu in [True, False]: + for test_param in self._test_params: + self.runTest(test_param, apply_relu) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index d3897483740faafa62befbaf873886139f1482d2..8bc4db8424f661bba65675f0cd1c2fc33696eda9 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -773,9 +773,9 @@ def mutual_information_penalty( structured_generator_inputs: A list of Tensors representing the random noise that must have high mutual information with the generator output. List length should match `predicted_distributions`. - predicted_distributions: A list of tf.Distributions. Predicted by the - recognizer, and used to evaluate the likelihood of the structured noise. - List length should match `structured_generator_inputs`. + predicted_distributions: A list of `tfp.distributions.Distribution`s. + Predicted by the recognizer, and used to evaluate the likelihood of the + structured noise. List length should match `structured_generator_inputs`. weights: Optional `Tensor` whose rank is either 0, or the same dimensions as `structured_generator_inputs`. scope: The scope for the operations performed in computing the loss. diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index a462b68e28be989eee04fe4ec5ee902d75e5d909..b9ac1bf15138c7e7d15ab3ebdac605d84921b6e5 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -91,9 +91,9 @@ class InfoGANModel( structured_generator_inputs: A list of Tensors representing the random noise that must have high mutual information with the generator output. List length should match `predicted_distributions`. - predicted_distributions: A list of tf.Distributions. Predicted by the - recognizer, and used to evaluate the likelihood of the structured noise. - List length should match `structured_generator_inputs`. + predicted_distributions: A list of `tfp.distributions.Distribution`s. + Predicted by the recognizer, and used to evaluate the likelihood of the + structured noise. List length should match `structured_generator_inputs`. discriminator_and_aux_fn: The original discriminator function that returns a tuple of (logits, `predicted_distributions`). """ diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 58f348034fdcaadd8d738517aef2a7e2f0172c13..64d670619905a427a84bee4b661228abca591fae 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -399,7 +399,7 @@ class StarGANModelTest(test.TestCase): target_tensor = train._generate_stargan_random_domain_target( batch_size, domain_numbers) - with self.test_session() as sess: + with self.cached_session() as sess: targets = sess.run(target_tensor) self.assertTupleEqual((batch_size, domain_numbers), targets.shape) for target in targets: @@ -676,7 +676,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(model_loss, namedtuples.GANLoss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 726f74c7b7addbd6c048d0b05f5695a77deb53b2..3549cedb70a6104ff3d3829d1b94cb5f08c5119c 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -30,19 +29,17 @@ limitations under the License. #include #include "tensorflow/contrib/gdr/gdr.pb.h" -#include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/common_runtime/process_state.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #endif // GOOGLE_CUDA -#include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/numa.h" namespace tensorflow { @@ -70,14 +67,11 @@ bool IsGDRAvailable() { int TryToReadNumaNode(ibv_device* device) { #if defined(__APPLE__) LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0"; - return 0; + return port::kNUMANoAffinity; #elif defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. - return 0; + return port::kNUMANoAffinity; #else - VLOG(2) << "Trying to read NUMA node for device: " << device->name; - static const int kUnknownNumaNode = -1; - auto filename = string(device->ibdev_path) + "/device/numa_node"; std::ifstream ifs(filename.c_str()); @@ -91,12 +85,12 @@ int TryToReadNumaNode(ibv_device* device) { << value << "), but there must be at least one NUMA node" ", so returning NUMA node zero"; - return 0; + return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; return value; } - return kUnknownNumaNode; + return port::kNUMANoAffinity; #endif } @@ -148,7 +142,8 @@ class GdrMemoryManager : public RemoteMemoryManager { ibv_mr* FindMemoryRegion(void* addr, size_t length); - void InsertMemoryRegion(void* addr, size_t length); + void InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name); void EvictMemoryRegion(void* addr, size_t length); @@ -158,6 +153,7 @@ class GdrMemoryManager : public RemoteMemoryManager { RdmaEndpointPtr listening_; std::atomic stopped_; int epfd_; + int numa_node_; // Server side endpoints // Accessed sequentially in Run() so not protected by lock @@ -183,26 +179,6 @@ class GdrMemoryManager : public RemoteMemoryManager { TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager); }; -// TODO(byronyi): remove this class and its registration when the default -// cpu_allocator() returns visitable allocator, or cpu_allocator() is no -// longer in use. -class BFCGdrAllocator : public BFCAllocator { - public: - BFCGdrAllocator() - : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, - true, "cpu_gdr_bfc") {} -}; -class BFCGdrAllocatorFactory : public AllocatorFactory { - public: - Allocator* CreateAllocator() override { return new BFCGdrAllocator; } - - virtual SubAllocator* CreateSubAllocator(int numa_node) { - return new BasicCPUAllocator(numa_node); - } -}; - -REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory); - GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) : host_(host), port_(port), @@ -271,45 +247,39 @@ Status GdrMemoryManager::Init() { "cannot add server to epoll"); } - Allocator* allocators[] = { -#if GOOGLE_CUDA - GPUProcessState::singleton()->GetCUDAHostAllocator(0), -#endif // GOOGLE_CUDA - ProcessState::singleton()->GetCPUAllocator(0), - cpu_allocator(), - }; + numa_node_ = TryToReadNumaNode(listening_->verbs->device); - using namespace std::placeholders; - VisitableAllocator::Visitor alloc_visitor = - std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2); - VisitableAllocator::Visitor free_visitor = - std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2); - - std::set instrumented_; - - // Host memory allocators - for (Allocator* allocator : allocators) { - auto* visitable_allocator = dynamic_cast(allocator); - CHECK(visitable_allocator) - << "is not visitable for instrumentation" << allocator->Name(); - // Make sure we don't instrument the same allocator twice - if (instrumented_.find(allocator) == std::end(instrumented_)) { - visitable_allocator->AddAllocVisitor(alloc_visitor); - visitable_allocator->AddFreeVisitor(free_visitor); - instrumented_.insert(allocator); - LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name(); - } - } + SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, + size_t num_bytes) { + VLOG(2) << "Registering RDMA capable memory region on numa_node " + << numa_node; + InsertMemoryRegion(ptr, num_bytes, strings::StrCat("CPU:", numa_node)); + }; + SubAllocator::Visitor free_visitor = [this](void* ptr, int numa_node, + size_t num_bytes) { + VLOG(2) << "De-registering RDMA capable memory region on numa_node " + << numa_node; + EvictMemoryRegion(ptr, num_bytes); + }; + ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor); + ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); + LOG(INFO) << "Instrumenting CPU allocator(s)"; #if GOOGLE_CUDA - VisitableAllocator::Visitor cuda_alloc_visitor = - std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2); if (IsGDRAvailable()) { - // Note we don't free allocated GPU memory so there is no free visitor - int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1; + int bus_id = numa_node_ + 1; + + SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, + size_t num_bytes) { + VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; + InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); + }; GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); - LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; + GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, + alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); + LOG(INFO) << "Instrumenting GPU allocator(s) with bus_id " << bus_id; } #endif // GOOGLE_CUDA @@ -429,7 +399,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( ibv_mr* mr = FindMemoryRegion(addr, length); #if GOOGLE_CUDA - if (!on_host) { + if (device->tensorflow_gpu_device_info() && !on_host) { Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); GPUUtil::CopyGPUTensorToCPU( @@ -480,11 +450,27 @@ void GdrMemoryManager::TransportOptionsFromTensor( #endif if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; + Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); + Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); + + std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); + VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; + + buffer = DMAHelper::buffer(&host_copy); + addr = buffer->data(); + length = buffer->size(); + + mr = FindMemoryRegion(addr, length); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + return; + } + + buffer->Ref(); + } else { + buffer->Ref(); } - buffer->Ref(); TensorKey tensor_key = next_key_++; { mutex_lock l(server_mu_); @@ -494,7 +480,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( uint64_t checksum = 0; if (VLOG_IS_ON(2)) { #ifdef GOOGLE_CUDA - if (!on_host) { + if (device->tensorflow_gpu_device_info() && !on_host) { checksum = GPUUtil::Checksum(device, device_context, tensor); } else { checksum = GPUUtil::Checksum(tensor); @@ -532,7 +518,8 @@ void GdrMemoryManager::TensorFromTransportOptions( Tensor host_copy; #if GOOGLE_CUDA if (mr == nullptr && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = + GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); buffer = DMAHelper::buffer(&host_copy); addr = buffer->data(); @@ -542,8 +529,18 @@ void GdrMemoryManager::TensorFromTransportOptions( #endif // GOOGLE_CUDA if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; + Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); + host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); + + buffer = DMAHelper::buffer(&host_copy); + addr = buffer->data(); + length = buffer->size(); + + mr = FindMemoryRegion(addr, length); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + return; + } } decltype(clients_)::iterator iter; @@ -592,7 +589,8 @@ void GdrMemoryManager::TensorFromTransportOptions( } #if GOOGLE_CUDA - if (host_copy.NumElements() > 0) { + if (device->tensorflow_gpu_device_info() && !on_host && + host_copy.NumElements() > 0) { uint64_t checksum = 0; if (VLOG_IS_ON(2)) { checksum = GPUUtil::Checksum(host_copy); @@ -622,6 +620,12 @@ void GdrMemoryManager::TensorFromTransportOptions( } #endif // GOOGLE_CUDA + if ((on_host || !device->tensorflow_gpu_device_info()) && + host_copy.NumElements() > 0) { + std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); + VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; + } + uint64_t end = Env::Default()->NowMicros(); VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() @@ -631,7 +635,7 @@ void GdrMemoryManager::TensorFromTransportOptions( uint64_t checksum = 0; if (VLOG_IS_ON(2)) { #ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && (!on_host)) { + if (device->tensorflow_gpu_device_info() && !on_host) { checksum = GPUUtil::Checksum(device, device_context, *tensor); } else { checksum = GPUUtil::Checksum(*tensor); @@ -692,7 +696,8 @@ ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { } } -void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) { +void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length, + const std::string& allocator_name) { if (length == 0) return; ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length); if (mr != nullptr) { @@ -700,7 +705,8 @@ void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) { auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); mrs_.insert(iter, {mr, &MRDeleter}); } else { - LOG(WARNING) << "Cannot register memory region"; + LOG(WARNING) << "Cannot register memory region allocated by " + << allocator_name; } } diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 97f38c923f4a19cedf3e16203ca1e66b7e5e45d2..0ebcdc26889bf8a574b1967cd8a6b1c466817127 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -214,7 +214,7 @@ class TransformTest(test.TestCase): def test_graph_replace_gradients(self): ops.reset_default_graph() - w = variables.Variable(0.0, name="w") + w = variables.VariableV1(0.0, name="w") y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2") g = gradients_impl.gradients(y, w, name="grad")[0] diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index fed8a771cc153c051e31e86dbd7885cbc3271f4c..27aed091c249caa6e50748419a93f3579e6632a4 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase): ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]]))) def testGrid2LSTMCellWithRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid2BasicRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) @@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase): [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellTied(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) @@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase): [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellWithRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid1LSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope: x = array_ops.zeros([1, 3]) @@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid3LSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase): """ def testGridRNNEdgeCasesLikeRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([3, 2]) @@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase): self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],)) def testGridRNNEdgeCasesNoOutput(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape(), (3, num_units)) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((3, input_size)) @@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase): def testGrid2LSTMCellLegacy(self): """Test for legacy case (when state_is_tuple=False).""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index d796e43d877e463fa4398741748013b2eb661155..f7f1189bb93c611719186a697c40f371644f63a2 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -51,7 +51,7 @@ class SequenceFileDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_repeats): # Dataset is repeated. for i in range(25): # 25 records. diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index 6e0e628655fbc32a43fad2dc4883b26c6ad57c48..bf398b838dfaaff6fdaf33a6cd7086ef13e43a3e 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -19,14 +19,14 @@ from __future__ import print_function from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import -from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -class SequenceFileDataset(Dataset): +class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" def __init__(self, filenames): diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9393b702d11a2ef84586f712d30c26fe2a8972bb --- /dev/null +++ b/tensorflow/contrib/ignite/BUILD @@ -0,0 +1,139 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "if_not_windows", + "if_windows", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +py_library( + name = "ignite", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/ignite_dataset_ops.cc", + "kernels/ignite_client.h", + "kernels/ignite_byte_swapper.h", + "kernels/ignite_plain_client.h", + "kernels/ignite_ssl_wrapper.h", + "kernels/ignite_ssl_wrapper.cc", + "kernels/ignite_binary_object_parser.h", + "kernels/ignite_binary_object_parser.cc", + "kernels/ignite_dataset.h", + "kernels/ignite_dataset.cc", + "kernels/ignite_dataset_iterator.h", + "kernels/ignite_dataset_iterator.cc", + ] + if_not_windows([ + "kernels/ignite_plain_client_unix.cc", + ]) + if_windows([ + "kernels/ignite_plain_client_windows.cc", + ]), + copts = if_windows([ + "-DWIN32_LEAN_AND_MEAN", + ]), + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@boringssl//:ssl", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/ignite_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":ignite_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"], +) + +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "ignite_op_loader", + srcs = ["python/ops/ignite_op_loader.py"], + dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/ignite:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +# The Apache Ignite servers have to setup before the test and tear down +# after the test manually. The docker engine has to be installed. +# +# To setup Apache Ignite servers: +# $ bash ./python/tests/start_ignite.sh +# +# To tear down Apache Ignite servers: +# $ bash ./python/tests/stop_ignite.sh +tf_py_test( + name = "ignite_dataset_test", + srcs = ["python/tests/ignite_dataset_test.py"], + additional_deps = [ + ":ignite", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "no_windows", + "notap", + ], +) diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..55c89d27996318dabb29bb15372411005301ebd9 --- /dev/null +++ b/tensorflow/contrib/ignite/README.md @@ -0,0 +1,167 @@ +# Ignite Dataset + +- [Overview](#overview) +- [Features](#features) + * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) + * [Structured Objects](#structured-objects) + * [Distributed Training](#distributed-training) + * [SSL Connection](#ssl-connection) + * [Windows Support](#windows-support) +- [Try it out](#try-it-out) +- [Limitations](#limitations) + +## Overview + +[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for +transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow. + +## Features + +Ignite Dataset provides features that that you can use in a wide range of cases. The most important and interesting features are described below. + +### Distributed In-Memory Datasource +[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid limitations of hard drive and store and operate with as much data as you need in distributed cluster. You can utilize +these benefits of Apache Ignite by using Ignite Dataset. Moreover, Ignite Dataset can be used for the following use-cases: +- If you have a **gigabyte** of data you can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations. At the same time, you can store your data in Apache Ignite on the same machine and use it as a datasource for TensorFlow and thus avoid these limitations. +- If you have a **terabyte** of data you probably still can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations again. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow and thus avoid these limitations. +- If you have a **petabyte** of data you can't keep it on a single machine. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow. + +Note that Apache Ignite is not just a step of ETL pipeline between a database or a data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. By choosing Apache Ignite and TensorFlow you are getting everything you need to work with operational or historical data and, at the same time, an ability to use this data for neural network training and inference. + +```bash +$ apache-ignite-fabric/bin/ignite.sh +$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/" + +jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY'); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY'); +jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR'); +``` + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> for _ in range(3): +>>> print(sess.run(next_obj)) + +{'key': 1, 'val': {'NAME': b'WARM KITTY'}} +{'key': 2, 'val': {'NAME': b'SOFT KITTY'}} +{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}} +``` + +### Structured Objects +[Apache Ignite](https://ignite.apache.org/) allows to store any type of objects. These objects can have any hierarchy. Ignite Dataset provides an ability to work with such objects. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES") +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> print(sess.run(next_obj)) + +{ + 'key': 'kitten.png', + 'val': { + 'metadata': { + 'file_name': b'kitten.png', + 'label': b'little ball of fur', + width: 800, + height: 600 + }, + 'pixels': [0, 0, 0, 0, ..., 0] + } +} +``` + Neural network training and other computations require transformations that can be done as part of [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) +>>> iterator = dataset.make_one_shot_iterator() +>>> next_obj = iterator.get_next() +>>> +>>> with tf.Session() as sess: +>>> print(sess.run(next_obj)) + +[0, 0, 0, 0, ..., 0] +``` + +### Distributed Training + +TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind the distributed neural network training is the ability to calculate gradients of loss functions (squares of the errors) on every partition of data (in terms of horizontal partitioning) and then sum them to get loss function gradient of the whole dataset. + + + +Using this ability we can calculate gradients on the nodes the data is stored on, reduce them and then finally update model parameters. It allows to avoid data transfers between nodes and thus to avoid network bottlenecks. + +Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition. + +Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset("IMAGES") +>>> +>>> # Compute gradients locally on every worker node. +>>> gradients = [] +>>> for i in range(5): +>>> with tf.device("/job:WORKER/task:%d" % i): +>>> device_iterator = dataset.make_one_shot_iterator() +>>> device_next_obj = device_iterator.get_next() +>>> gradient = compute_gradient(device_next_obj) +>>> gradients.append(gradient) +>>> +>>> # Aggregate them on master node. +>>> result_gradient = tf.reduce_sum(gradients) +>>> +>>> with tf.Session("grpc://localhost:10000") as sess: +>>> print(sess.run(result_gradient)) +``` + +High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. + +### SSL Connection + +Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. + +```python +>>> import tensorflow as tf +>>> from tensorflow.contrib.ignite import IgniteDataset +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite") +>>> ... +``` + +### Windows Support + +Ignite Dataset is fully compatible with Windows. You can use it as part of TensorFlow on your Windows workstation as well as on Linux/MacOS systems. + +## Try it out + +The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: + +``` +docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist +``` + +After that you will be able to work with it following way: + +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") + +## Limitations + +Presently, Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object. Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure. diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f42947696f76e168f77b2316758209f1f71a7915 --- /dev/null +++ b/tensorflow/contrib/ignite/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""IgniteDataset that allows to get data from Apache Ignite. + +Apache Ignite is a memory-centric distributed database, caching, and +processing platform for transactional, analytical, and streaming workloads, +delivering in-memory speeds at petabyte scale. This contrib package +contains an integration between Apache Ignite and TensorFlow. The +integration is based on tf.data from TensorFlow side and Binary Client +Protocol from Apache Ignite side. It allows to use Apache Ignite as a +datasource for neural network training, inference and all other +computations supported by TensorFlow. Ignite Dataset is based on Apache +Ignite Binary Client Protocol: +https://apacheignite.readme.io/v2.6/docs/binary-client-protocol. + +@@IgniteDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops import IgniteDataset +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "IgniteDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c8a7d44b07b43f788bcbc0850b5162cc14dd951 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc @@ -0,0 +1,334 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {} + +Status BinaryObjectParser::Parse(uint8_t** ptr, + std::vector* out_tensors, + std::vector* types) const { + uint8_t object_type_id = ParseByte(ptr); + + // Skip non-leaf nodes. + if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ) + types->push_back(object_type_id); + + switch (object_type_id) { + case BYTE: { + out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({})); + out_tensors->back().scalar()() = ParseByte(ptr); + break; + } + case SHORT: { + out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({})); + out_tensors->back().scalar()() = ParseShort(ptr); + break; + } + case USHORT: { + out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({})); + out_tensors->back().scalar()() = ParseUnsignedShort(ptr); + break; + } + case INT: { + out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({})); + out_tensors->back().scalar()() = ParseInt(ptr); + break; + } + case LONG: { + out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); + out_tensors->back().scalar()() = ParseLong(ptr); + break; + } + case FLOAT: { + out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({})); + out_tensors->back().scalar()() = ParseFloat(ptr); + break; + } + case DOUBLE: { + out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({})); + out_tensors->back().scalar()() = ParseDouble(ptr); + break; + } + case BOOL: { + out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({})); + out_tensors->back().scalar()() = ParseBool(ptr); + break; + } + case STRING: { + out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({})); + out_tensors->back().scalar()() = ParseString(ptr); + break; + } + case DATE: { + out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); + out_tensors->back().scalar()() = ParseLong(ptr); + break; + } + case BYTE_ARR: { + int32_t length = ParseInt(ptr); + uint8_t* arr = ParseByteArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_UINT8, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case SHORT_ARR: { + int32_t length = ParseInt(ptr); + int16_t* arr = ParseShortArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT16, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case USHORT_ARR: { + int32_t length = ParseInt(ptr); + uint16_t* arr = ParseUnsignedShortArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_UINT16, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case INT_ARR: { + int32_t length = ParseInt(ptr); + int32_t* arr = ParseIntArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT32, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case LONG_ARR: { + int32_t length = ParseInt(ptr); + int64_t* arr = ParseLongArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT64, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case FLOAT_ARR: { + int32_t length = ParseInt(ptr); + float* arr = ParseFloatArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case DOUBLE_ARR: { + int32_t length = ParseInt(ptr); + double* arr = ParseDoubleArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case BOOL_ARR: { + int32_t length = ParseInt(ptr); + bool* arr = ParseBoolArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_BOOL, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case STRING_ARR: { + int32_t length = ParseInt(ptr); + out_tensors->emplace_back(cpu_allocator(), DT_STRING, + TensorShape({length})); + for (int32_t i = 0; i < length; i++) + out_tensors->back().vec()(i) = ParseString(ptr); + break; + } + case DATE_ARR: { + int32_t length = ParseInt(ptr); + int64_t* arr = ParseLongArr(ptr, length); + out_tensors->emplace_back(cpu_allocator(), DT_INT64, + TensorShape({length})); + std::copy_n(arr, length, out_tensors->back().flat().data()); + break; + } + case WRAPPED_OBJ: { + int32_t byte_arr_size = ParseInt(ptr); + TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); + int32_t offset = ParseInt(ptr); + + break; + } + case COMPLEX_OBJ: { + uint8_t version = ParseByte(ptr); + int16_t flags = ParseShort(ptr); + int32_t type_id = ParseInt(ptr); + int32_t hash_code = ParseInt(ptr); + int32_t length = ParseInt(ptr); + int32_t schema_id = ParseInt(ptr); + int32_t schema_offset = ParseInt(ptr); + + // 24 is size of header just read. + uint8_t* end = *ptr + schema_offset - 24; + int32_t i = 0; + while (*ptr < end) { + i++; + TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); + } + + *ptr += (length - schema_offset); + + break; + } + default: { + return errors::Unknown("Unknowd binary type (type id ", + (int)object_type_id, ")"); + } + } + + return Status::OK(); +} + +uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const { + uint8_t res = **ptr; + *ptr += 1; + + return res; +} + +int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const { + int16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt16(res); + *ptr += 2; + + return *res; +} + +uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const { + uint16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredUnsignedInt16(res); + *ptr += 2; + + return *res; +} + +int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const { + int32_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt32(res); + *ptr += 4; + + return *res; +} + +int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const { + int64_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt64(res); + *ptr += 8; + + return *res; +} + +float BinaryObjectParser::ParseFloat(uint8_t** ptr) const { + float* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredFloat(res); + *ptr += 4; + + return *res; +} + +double BinaryObjectParser::ParseDouble(uint8_t** ptr) const { + double* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredDouble(res); + *ptr += 8; + + return *res; +} + +bool BinaryObjectParser::ParseBool(uint8_t** ptr) const { + bool res = **reinterpret_cast(ptr); + *ptr += 1; + + return res; +} + +string BinaryObjectParser::ParseString(uint8_t** ptr) const { + int32_t length = ParseInt(ptr); + string res(*reinterpret_cast(ptr), length); + *ptr += length; + + return res; +} + +uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const { + uint8_t* res = *reinterpret_cast(ptr); + *ptr += length; + + return res; +} + +int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const { + int16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt16Arr(res, length); + *ptr += length * 2; + + return res; +} + +uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr, + int length) const { + uint16_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length); + *ptr += length * 2; + + return res; +} + +int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const { + int32_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt32Arr(res, length); + *ptr += length * 4; + + return res; +} + +int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const { + int64_t* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredInt64Arr(res, length); + *ptr += length * 8; + + return res; +} + +float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const { + float* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredFloatArr(res, length); + *ptr += length * 4; + + return res; +} + +double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const { + double* res = *reinterpret_cast(ptr); + byte_swapper_.SwapIfRequiredDoubleArr(res, length); + *ptr += length * 8; + + return res; +} + +bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const { + bool* res = *reinterpret_cast(ptr); + *ptr += length; + + return res; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1f856643a790de6acaa82d4b8ad894fd364376 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ + +#include +#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class BinaryObjectParser { + public: + BinaryObjectParser(); + Status Parse(uint8_t** ptr, std::vector* out_tensors, + std::vector* types) const; + + private: + uint8_t ParseByte(uint8_t** ptr) const; + int16_t ParseShort(uint8_t** ptr) const; + uint16_t ParseUnsignedShort(uint8_t** ptr) const; + int32_t ParseInt(uint8_t** ptr) const; + int64_t ParseLong(uint8_t** ptr) const; + float ParseFloat(uint8_t** ptr) const; + double ParseDouble(uint8_t** ptr) const; + bool ParseBool(uint8_t** ptr) const; + string ParseString(uint8_t** ptr) const; + uint8_t* ParseByteArr(uint8_t** ptr, int length) const; + int16_t* ParseShortArr(uint8_t** ptr, int length) const; + uint16_t* ParseUnsignedShortArr(uint8_t** ptr, int length) const; + int32_t* ParseIntArr(uint8_t** ptr, int length) const; + int64_t* ParseLongArr(uint8_t** ptr, int length) const; + float* ParseFloatArr(uint8_t** ptr, int length) const; + double* ParseDoubleArr(uint8_t** ptr, int length) const; + bool* ParseBoolArr(uint8_t** ptr, int length) const; + + const ByteSwapper byte_swapper_; +}; + +enum ObjectType { + BYTE = 1, + SHORT = 2, + INT = 3, + LONG = 4, + FLOAT = 5, + DOUBLE = 6, + USHORT = 7, + BOOL = 8, + STRING = 9, + DATE = 11, + BYTE_ARR = 12, + SHORT_ARR = 13, + INT_ARR = 14, + LONG_ARR = 15, + FLOAT_ARR = 16, + DOUBLE_ARR = 17, + USHORT_ARR = 18, + BOOL_ARR = 19, + STRING_ARR = 20, + DATE_ARR = 22, + WRAPPED_OBJ = 27, + COMPLEX_OBJ = 103 +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h new file mode 100644 index 0000000000000000000000000000000000000000..46df3e39dc4ec6dd4ef5730a184264eaa9fc5872 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ + +#include +#include "tensorflow/core/platform/byte_order.h" + +namespace tensorflow { + +class ByteSwapper { + public: + ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; } + + inline void SwapIfRequiredInt16(int16_t *x) const { + if (swap_) { + Swap16(x); + } + } + + inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const { + if (swap_) { + Swap16(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt32(int32_t *x) const { + if (swap_) { + Swap32(x); + } + } + + inline void SwapIfRequiredFloat(float *x) const { + if (swap_) { + Swap32(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt64(int64_t *x) const { + if (swap_) { + Swap64(x); + } + } + + inline void SwapIfRequiredDouble(double *x) const { + if (swap_) { + Swap64(reinterpret_cast(x)); + } + } + + inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap16(&x[i]); + } + } + + inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, + int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap16(reinterpret_cast(&x[i])); + } + } + + inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap32(&x[i]); + } + } + + inline void SwapIfRequiredFloatArr(float *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap32(reinterpret_cast(&x[i])); + } + } + + inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) Swap64(&x[i]); + } + } + + inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const { + if (swap_) { + for (int32_t i = 0; i < length; i++) + Swap64(reinterpret_cast(&x[i])); + } + } + + private: + inline void Swap16(int16_t *x) const { + *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF); + } + + inline void Swap32(int32_t *x) const { + *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) | + (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF); + } + + inline void Swap64(int64_t *x) const { + *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) | + (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) | + (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) | + (((*x >> 48) & 0xFF) << 8) | ((*x >> 56) & 0xFF); + } + + bool swap_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h new file mode 100644 index 0000000000000000000000000000000000000000..459b50b48fd95ad105bccaca4076160e0ef152ee --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_client.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Client { + public: + Client(bool big_endian) : byte_swapper_(ByteSwapper(big_endian)) {} + virtual Status Connect() = 0; + virtual Status Disconnect() = 0; + virtual bool IsConnected() = 0; + virtual int GetSocketDescriptor() = 0; + virtual Status ReadData(uint8_t *buf, const int32_t length) = 0; + virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0; + + inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); } + + inline Status ReadShort(int16_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2)); + byte_swapper_.SwapIfRequiredInt16(data); + + return Status::OK(); + } + + inline Status ReadInt(int32_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4)); + byte_swapper_.SwapIfRequiredInt32(data); + + return Status::OK(); + } + + inline Status ReadLong(int64_t *data) { + TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8)); + byte_swapper_.SwapIfRequiredInt64(data); + + return Status::OK(); + } + + inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } + + inline Status WriteShort(const int16_t data) { + int16_t tmp = data; + byte_swapper_.SwapIfRequiredInt16(&tmp); + return WriteData((uint8_t *)&tmp, 2); + } + + inline Status WriteInt(const int32_t data) { + int32_t tmp = data; + byte_swapper_.SwapIfRequiredInt32(&tmp); + return WriteData((uint8_t *)&tmp, 4); + } + + inline Status WriteLong(const int64_t data) { + int64_t tmp = data; + byte_swapper_.SwapIfRequiredInt64(&tmp); + return WriteData((uint8_t *)&tmp, 8); + } + + private: + const ByteSwapper byte_swapper_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4a7d3c513a796c9d95b371bedc609fd75188817 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +IgniteDataset::IgniteDataset(OpKernelContext* ctx, string cache_name, + string host, int32 port, bool local, int32 part, + int32 page_size, string username, string password, + string certfile, string keyfile, + string cert_password, std::vector schema, + std::vector permutation, + DataTypeVector dtypes, + std::vector shapes) + : DatasetBase(DatasetContext(ctx)), + cache_name_(std::move(cache_name)), + host_(std::move(host)), + port_(port), + local_(local), + part_(part), + page_size_(page_size), + username_(std::move(username)), + password_(std::move(password)), + certfile_(std::move(certfile)), + keyfile_(std::move(keyfile)), + cert_password_(std::move(cert_password)), + schema_(std::move(schema)), + permutation_(std::move(permutation)), + dtypes_(dtypes), + shapes_(shapes) { + LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name_ + << "', host='" << host_ << "', port=" << port_ + << ", local=" << local_ << ", part=" << part_ + << ", page_size=" << page_size_ << ", username='" << username_ + << "', certfile='" << certfile_ << "', keyfile='" + << keyfile_ + "']"; +} + +IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; } + +std::unique_ptr IgniteDataset::MakeIteratorInternal( + const string& prefix) const { + return std::unique_ptr(new IgniteDatasetIterator( + {this, strings::StrCat(prefix, "::Ignite")}, std::move(this->host_), + this->port_, std::move(this->cache_name_), this->local_, this->part_, + this->page_size_, std::move(this->username_), std::move(this->password_), + std::move(this->certfile_), std::move(this->keyfile_), + std::move(this->cert_password_), std::move(this->schema_), + std::move(this->permutation_))); +} + +const DataTypeVector& IgniteDataset::output_dtypes() const { return dtypes_; } + +const std::vector& IgniteDataset::output_shapes() const { + return shapes_; +} + +string IgniteDataset::DebugString() const { return "IgniteDatasetOp::Dataset"; } + +Status IgniteDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { + return errors::Unimplemented( + "IgniteDataset does not support 'AsGraphDefInternal'"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..66bfdf2e2a168e59cd2fec8e2ac5b8fd482d5c15 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { + +class IgniteDataset : public DatasetBase { + public: + IgniteDataset(OpKernelContext* ctx, string cache_name, string host, + int32 port, bool local, int32 part, int32 page_size, + string username, string password, string certfile, + string keyfile, string cert_password, std::vector schema, + std::vector permutation, DataTypeVector dtypes, + std::vector shapes); + ~IgniteDataset(); + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + const DataTypeVector& output_dtypes() const override; + const std::vector& output_shapes() const override; + string DebugString() const override; + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + const string cache_name_; + const string host_; + const int32 port_; + const bool local_; + const int32 part_; + const int32 page_size_; + const string username_; + const string password_; + const string certfile_; + const string keyfile_; + const string cert_password_; + const std::vector schema_; + const std::vector permutation_; + const DataTypeVector dtypes_; + const std::vector shapes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc new file mode 100644 index 0000000000000000000000000000000000000000..5da9127aa6a3a4bc16347e6890cc1ba44406c0d5 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc @@ -0,0 +1,422 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +IgniteDatasetIterator::IgniteDatasetIterator( + const Params& params, string host, int32 port, string cache_name, + bool local, int32 part, int32 page_size, string username, string password, + string certfile, string keyfile, string cert_password, + std::vector schema, std::vector permutation) + : DatasetIterator(params), + cache_name_(std::move(cache_name)), + local_(local), + part_(part), + page_size_(page_size), + username_(std::move(username)), + password_(std::move(password)), + schema_(std::move(schema)), + permutation_(std::move(permutation)), + remainder_(-1), + cursor_id_(-1), + last_page_(false), + valid_state_(true) { + Client* p_client = new PlainClient(std::move(host), port, false); + + if (certfile.empty()) + client_ = std::unique_ptr(p_client); + else + client_ = std::unique_ptr( + new SslWrapper(std::unique_ptr(p_client), std::move(certfile), + std::move(keyfile), std::move(cert_password), false)); + + LOG(INFO) << "Ignite Dataset Iterator created"; +} + +IgniteDatasetIterator::~IgniteDatasetIterator() { + Status status = CloseConnection(); + if (!status.ok()) LOG(ERROR) << status.ToString(); + + LOG(INFO) << "Ignite Dataset Iterator destroyed"; +} + +Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) { + mutex_lock l(mutex_); + + if (valid_state_) { + Status status = + GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence); + + if (!status.ok()) valid_state_ = false; + + return status; + } + + return errors::Unknown("Iterator is invalid"); +} + +Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) { + return errors::Unimplemented( + "Iterator for IgniteDataset does not support 'SaveInternal'"); +} + +Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) { + return errors::Unimplemented( + "Iterator for IgniteDataset does not support 'RestoreInternal')"); +} + +Status IgniteDatasetIterator::GetNextInternalWithValidState( + IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) { + if (remainder_ == 0 && last_page_) { + cursor_id_ = -1; + *end_of_sequence = true; + + return Status::OK(); + } else { + TF_RETURN_IF_ERROR(EstablishConnection()); + + if (remainder_ == -1) { + TF_RETURN_IF_ERROR(ScanQuery()); + } else if (remainder_ == 0) { + TF_RETURN_IF_ERROR(LoadNextPage()); + } + + uint8_t* initial_ptr = ptr_; + std::vector tensors; + std::vector types; + + TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key + TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val + + remainder_ -= (ptr_ - initial_ptr); + + TF_RETURN_IF_ERROR(CheckTypes(types)); + + for (size_t i = 0; i < tensors.size(); i++) + out_tensors->push_back(tensors[permutation_[i]]); + + *end_of_sequence = false; + + return Status::OK(); + } + + *end_of_sequence = true; + + return Status::OK(); +} + +Status IgniteDatasetIterator::EstablishConnection() { + if (!client_->IsConnected()) { + TF_RETURN_IF_ERROR(client_->Connect()); + + Status status = Handshake(); + if (!status.ok()) { + Status disconnect_status = client_->Disconnect(); + if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString(); + + return status; + } + } + + return Status::OK(); +} + +Status IgniteDatasetIterator::CloseConnection() { + if (cursor_id_ != -1 && !last_page_) { + TF_RETURN_IF_ERROR(EstablishConnection()); + + TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID + + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + if (res_len < kMinResLength) + return errors::Unknown("Close Resource Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + if (status != 0) { + uint8_t err_msg_header; + TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Close Resource Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Close Resource Error [status=", status, "]"); + } + + cursor_id_ = -1; + + return client_->Disconnect(); + } else { + LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed"; + } + + return client_->IsConnected() ? client_->Disconnect() : Status::OK(); +} + +Status IgniteDatasetIterator::Handshake() { + int32_t msg_len = kHandshakeReqDefaultLength; + + if (username_.empty()) + msg_len += 1; + else + msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length. + + if (password_.empty()) + msg_len += 1; + else + msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length. + + TF_RETURN_IF_ERROR(client_->WriteInt(msg_len)); + TF_RETURN_IF_ERROR(client_->WriteByte(1)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion)); + TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion)); + TF_RETURN_IF_ERROR(client_->WriteByte(2)); + if (username_.empty()) { + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); + } else { + TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal)); + TF_RETURN_IF_ERROR(client_->WriteInt(username_.length())); + TF_RETURN_IF_ERROR( + client_->WriteData(reinterpret_cast(username_.c_str()), + username_.length())); + } + + if (password_.empty()) { + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); + } else { + TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal)); + TF_RETURN_IF_ERROR(client_->WriteInt(password_.length())); + TF_RETURN_IF_ERROR( + client_->WriteData(reinterpret_cast(password_.c_str()), + password_.length())); + } + + int32_t handshake_res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len)); + uint8_t handshake_res; + TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res)); + + if (handshake_res != 1) { + int16_t serv_ver_major; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major)); + int16_t serv_ver_minor; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor)); + int16_t serv_ver_patch; + TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch)); + uint8_t header; + TF_RETURN_IF_ERROR(client_->ReadByte(&header)); + + if (header == kStringVal) { + int32_t length; + TF_RETURN_IF_ERROR(client_->ReadInt(&length)); + + uint8_t* err_msg_c = new uint8_t[length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length)); + string err_msg(reinterpret_cast(err_msg_c), length); + + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, ", message='", err_msg, "']"); + } else if (header == kNullVal) { + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, "]"); + } else { + return errors::Unknown("Handshake Error [result=", handshake_res, + ", version=", serv_ver_major, ".", serv_ver_minor, + ".", serv_ver_patch, "]"); + } + } + + return Status::OK(); +} + +Status IgniteDatasetIterator::ScanQuery() { + TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR( + client_->WriteInt(JavaHashCode(cache_name_))); // Cache name + TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags + TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object + TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size + TF_RETURN_IF_ERROR(client_->WriteInt(part_)); // part_ition to query + TF_RETURN_IF_ERROR(client_->WriteByte(local_)); // local_ flag + + uint64 wait_start = Env::Default()->NowMicros(); + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + int64_t wait_stop = Env::Default()->NowMicros(); + + LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms"; + + if (res_len < kMinResLength) + return errors::Unknown("Scan Query Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + + if (status != 0) { + uint8_t err_msg_header; + TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Scan Query Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Scan Query Error [status=", status, "]"); + } + + TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_)); + + int32_t row_cnt; + TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt)); + + int32_t page_size = res_len - kScanQueryResHeaderLength; + + return ReceivePage(page_size); +} + +Status IgniteDatasetIterator::LoadNextPage() { + TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength)); + TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode)); + TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID + TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID + + uint64 wait_start = Env::Default()->NowMicros(); + int32_t res_len; + TF_RETURN_IF_ERROR(client_->ReadInt(&res_len)); + uint64 wait_stop = Env::Default()->NowMicros(); + + LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000 + << " ms"; + + if (res_len < kMinResLength) + return errors::Unknown("Load Next Page Response is corrupted"); + + int64_t req_id; + TF_RETURN_IF_ERROR(client_->ReadLong(&req_id)); + + int32_t status; + TF_RETURN_IF_ERROR(client_->ReadInt(&status)); + + if (status != 0) { + uint8_t err_msg_header; + TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header)); + + if (err_msg_header == kStringVal) { + int32_t err_msg_length; + TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length)); + + uint8_t* err_msg_c = new uint8_t[err_msg_length]; + auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; }); + TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length)); + string err_msg(reinterpret_cast(err_msg_c), err_msg_length); + + return errors::Unknown("Load Next Page Error [status=", status, + ", message=", err_msg, "]"); + } + return errors::Unknown("Load Next Page Error [status=", status, "]"); + } + + int32_t row_cnt; + TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt)); + + int32_t page_size = res_len - kLoadNextPageResHeaderLength; + + return ReceivePage(page_size); +} + +Status IgniteDatasetIterator::ReceivePage(int32_t page_size) { + remainder_ = page_size; + page_ = std::unique_ptr(new uint8_t[remainder_]); + ptr_ = page_.get(); + + uint64 start = Env::Default()->NowMicros(); + TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_)); + uint64 stop = Env::Default()->NowMicros(); + + double size_in_mb = 1.0 * remainder_ / 1024 / 1024; + double time_in_s = 1.0 * (stop - start) / 1000 / 1000; + LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000 + << " ms download speed " << size_in_mb / time_in_s << " Mb/sec"; + + uint8_t last_page_b; + TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b)); + + last_page_ = !last_page_b; + + return Status::OK(); +} + +Status IgniteDatasetIterator::CheckTypes(const std::vector& types) { + if (schema_.size() != types.size()) + return errors::Unknown("Object has unexpected schema"); + + for (size_t i = 0; i < schema_.size(); i++) { + if (schema_[i] != types[permutation_[i]]) + return errors::Unknown("Object has unexpected schema"); + } + + return Status::OK(); +} + +int32_t IgniteDatasetIterator::JavaHashCode(string str) const { + int32_t h = 0; + for (char& c : str) { + h = 31 * h + c; + } + return h; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..c499e2c9ccfac5c15db08c8fd8b26c37aa0404f3 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ + +#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class IgniteDatasetIterator : public DatasetIterator { + public: + IgniteDatasetIterator(const Params& params, string host, int32 port, + string cache_name, bool local, int32 part, + int32 page_size, string username, string password, + string certfile, string keyfile, string cert_password, + std::vector schema, + std::vector permutation); + ~IgniteDatasetIterator(); + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override; + + protected: + Status SaveInternal(IteratorStateWriter* writer) override; + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override; + + private: + Status GetNextInternalWithValidState(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence); + + Status EstablishConnection(); + Status CloseConnection(); + Status Handshake(); + Status ScanQuery(); + Status LoadNextPage(); + Status ReceivePage(int32_t page_size); + Status CheckTypes(const std::vector& types); + int32_t JavaHashCode(string str) const; + + std::unique_ptr client_; + BinaryObjectParser parser_; + + const string cache_name_; + const bool local_; + const int32 part_; + const int32 page_size_; + const string username_; + const string password_; + const std::vector schema_; + const std::vector permutation_; + + int32_t remainder_; + int64_t cursor_id_; + bool last_page_; + + bool valid_state_; + + mutex mutex_; + + std::unique_ptr page_; + uint8_t* ptr_; +}; + +constexpr uint8_t kNullVal = 101; +constexpr uint8_t kStringVal = 9; +constexpr uint8_t kProtocolMajorVersion = 1; +constexpr uint8_t kProtocolMinorVersion = 1; +constexpr uint8_t kProtocolPatchVersion = 0; +constexpr int16_t kScanQueryOpcode = 2000; +constexpr int16_t kLoadNextPageOpcode = 2001; +constexpr int16_t kCloseConnectionOpcode = 0; +constexpr int32_t kScanQueryReqLength = 25; +constexpr int32_t kScanQueryResHeaderLength = 25; +constexpr int32_t kLoadNextPageReqLength = 18; +constexpr int32_t kLoadNextPageResHeaderLength = 17; +constexpr int32_t kCloseConnectionReqLength = 18; +constexpr int32_t kHandshakeReqDefaultLength = 8; +constexpr int32_t kMinResLength = 12; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f75b1c5ff55ca9ee493148ff79c2edd4b15ac42a --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc @@ -0,0 +1,198 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +Status SchemaToTypes(const std::vector& schema, DataTypeVector* dtypes) { + for (auto e : schema) { + if (e == BYTE || e == BYTE_ARR) { + dtypes->push_back(DT_UINT8); + } else if (e == SHORT || e == SHORT_ARR) { + dtypes->push_back(DT_INT16); + } else if (e == INT || e == INT_ARR) { + dtypes->push_back(DT_INT32); + } else if (e == LONG || e == LONG_ARR) { + dtypes->push_back(DT_INT64); + } else if (e == FLOAT || e == FLOAT_ARR) { + dtypes->push_back(DT_FLOAT); + } else if (e == DOUBLE || e == DOUBLE_ARR) { + dtypes->push_back(DT_DOUBLE); + } else if (e == USHORT || e == USHORT_ARR) { + dtypes->push_back(DT_UINT8); + } else if (e == BOOL || e == BOOL_ARR) { + dtypes->push_back(DT_BOOL); + } else if (e == STRING || e == STRING_ARR) { + dtypes->push_back(DT_STRING); + } else { + return errors::Unknown("Unexpected type in schema [type_id=", e, "]"); + } + } + + return Status::OK(); +} + +Status SchemaToShapes(const std::vector& schema, + std::vector* shapes) { + for (auto e : schema) { + if (e >= 1 && e < 10) { + shapes->push_back(PartialTensorShape({})); + } else if (e >= 12 && e < 21) { + shapes->push_back(PartialTensorShape({-1})); + } else { + return errors::Unknown("Unexpected type in schema [type_id=", e, "]"); + } + } + + return Status::OK(); +} + +class IgniteDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string cache_name = ""; + string host = ""; + int32 port = -1; + bool local = false; + int32 part = -1; + int32 page_size = -1; + string username = ""; + string password = ""; + string certfile = ""; + string keyfile = ""; + string cert_password = ""; + + const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME"); + const char* env_host = std::getenv("IGNITE_DATASET_HOST"); + const char* env_port = std::getenv("IGNITE_DATASET_PORT"); + const char* env_local = std::getenv("IGNITE_DATASET_LOCAL"); + const char* env_part = std::getenv("IGNITE_DATASET_PART"); + const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE"); + const char* env_username = std::getenv("IGNITE_DATASET_USERNAME"); + const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD"); + const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE"); + const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE"); + const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD"); + + if (env_cache_name) { + cache_name = string(env_cache_name); + } else { + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "cache_name", &cache_name)); + } + + if (env_host) { + host = string(env_host); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "host", &host)); + } + + if (env_port) { + OP_REQUIRES(ctx, strings::safe_strto32(env_port, &port), + errors::InvalidArgument("IGNITE_DATASET_PORT environment " + "variable is not a valid integer: ", + env_port)); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "port", &port)); + } + + if (env_local) { + local = true; + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "local", &local)); + } + + if (env_part) { + OP_REQUIRES(ctx, strings::safe_strto32(env_part, &part), + errors::InvalidArgument("IGNITE_DATASET_PART environment " + "variable is not a valid integer: ", + env_part)); + } else { + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "part", &part)); + } + + if (env_page_size) { + OP_REQUIRES(ctx, strings::safe_strto32(env_page_size, &page_size), + errors::InvalidArgument("IGNITE_DATASET_PAGE_SIZE " + "environment variable is not a valid " + "integer: ", + env_page_size)); + } else { + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "page_size", &page_size)); + } + + if (env_username) username = string(env_username); + + if (env_password) password = string(env_password); + + if (env_certfile) certfile = string(env_certfile); + + if (env_keyfile) keyfile = string(env_keyfile); + + if (env_cert_password) cert_password = string(env_cert_password); + + const Tensor* schema_tensor; + OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor)); + OP_REQUIRES(ctx, schema_tensor->dims() == 1, + errors::InvalidArgument("`schema` must be a vector.")); + + std::vector schema; + schema.reserve(schema_tensor->NumElements()); + for (int i = 0; i < schema_tensor->NumElements(); i++) { + schema.push_back(schema_tensor->flat()(i)); + } + + const Tensor* permutation_tensor; + OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor)); + OP_REQUIRES(ctx, permutation_tensor->dims() == 1, + errors::InvalidArgument("`permutation` must be a vector.")); + + std::vector permutation; + permutation.resize(permutation_tensor->NumElements()); + for (int i = 0; i < permutation_tensor->NumElements(); i++) { + // Inversed permutation. + permutation[permutation_tensor->flat()(i)] = i; + } + + DataTypeVector dtypes; + std::vector shapes; + + OP_REQUIRES_OK(ctx, SchemaToTypes(schema, &dtypes)); + OP_REQUIRES_OK(ctx, SchemaToShapes(schema, &shapes)); + + *output = new IgniteDataset( + ctx, std::move(cache_name), std::move(host), port, local, part, + page_size, std::move(username), std::move(password), + std::move(certfile), std::move(keyfile), std::move(cert_password), + std::move(schema), std::move(permutation), std::move(dtypes), + std::move(shapes)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU), + IgniteDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h new file mode 100644 index 0000000000000000000000000000000000000000..75424c19ee4b7df5378aa23cb41db1752e8d0651 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/ignite_client.h" + +namespace tensorflow { + +class PlainClient : public Client { + public: + PlainClient(string host, int port, bool big_endian); + ~PlainClient(); + + Status Connect() override; + Status Disconnect() override; + bool IsConnected() override; + int GetSocketDescriptor() override; + Status ReadData(uint8_t* buf, const int32_t length) override; + Status WriteData(const uint8_t* buf, const int32_t length) override; + + private: + const string host_; + const int port_; + int sock_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf672942c61e1239332711db12e62088737c4f41 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc @@ -0,0 +1,123 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" + +#include +#include +#include +#include + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +PlainClient::PlainClient(string host, int port, bool big_endian) + : Client(big_endian), host_(std::move(host)), port_(port), sock_(-1) {} + +PlainClient::~PlainClient() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } +} + +Status PlainClient::Connect() { + if (sock_ == -1) { + sock_ = socket(AF_INET, SOCK_STREAM, 0); + if (sock_ == -1) return errors::Internal("Failed to create socket"); + } + + sockaddr_in server; + + server.sin_addr.s_addr = inet_addr(host_.c_str()); + if (server.sin_addr.s_addr == -1) { + hostent* he; + in_addr** addr_list; + + if ((he = gethostbyname(host_.c_str())) == NULL) + return errors::Internal("Failed to resolve hostname \"", host_, "\""); + + addr_list = (in_addr**)he->h_addr_list; + if (addr_list[0] != NULL) server.sin_addr = *addr_list[0]; + } + + server.sin_family = AF_INET; + server.sin_port = htons(port_); + + if (connect(sock_, (sockaddr*)&server, sizeof(server)) < 0) + return errors::Internal("Failed to connect to \"", host_, ":", port_, "\""); + + LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established"; + + return Status::OK(); +} + +Status PlainClient::Disconnect() { + int close_res = close(sock_); + sock_ = -1; + + LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" is closed"; + + return close_res == 0 + ? Status::OK() + : errors::Internal("Failed to correctly close connection"); +} + +bool PlainClient::IsConnected() { return sock_ != -1; } + +int PlainClient::GetSocketDescriptor() { return sock_; } + +Status PlainClient::ReadData(uint8_t* buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = recv(sock_, buf, length - received, 0); + + if (res < 0) + return errors::Internal("Error occurred while reading from socket: ", res, + ", ", string(strerror(errno))); + + if (res == 0) return errors::Internal("Server closed connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status PlainClient::WriteData(const uint8_t* buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = send(sock_, buf, length - sent, 0); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", res, + ", ", string(strerror(errno))); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc new file mode 100644 index 0000000000000000000000000000000000000000..dad5aace5fabe1df58bb9579bf578f4c35324315 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" + +#define WIN32_LEAN_AND_MEAN +#include +#include +#include + +#pragma comment(lib, "Ws2_32.lib") +#pragma comment(lib, "Mswsock.lib") +#pragma comment(lib, "AdvApi32.lib") + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +PlainClient::PlainClient(string host, int port, bool big_endian) + : Client(big_endian), + host_(std::move(host)), + port_(port), + sock_(INVALID_SOCKET) {} + +PlainClient::~PlainClient() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } +} + +Status PlainClient::Connect() { + WSADATA wsaData; + addrinfo *result = NULL, *ptr = NULL, hints; + + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) return errors::Internal("WSAStartup failed with error: ", res); + + ZeroMemory(&hints, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + res = getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints, + &result); + if (res != 0) return errors::Internal("Getaddrinfo failed with error: ", res); + + auto clean = gtl::MakeCleanup([result] { freeaddrinfo(result); }); + + for (ptr = result; ptr != NULL; ptr = ptr->ai_next) { + sock_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol); + if (sock_ == INVALID_SOCKET) { + WSACleanup(); + return errors::Internal("Socket failed with error: ", WSAGetLastError()); + } + + res = connect(sock_, ptr->ai_addr, (int)ptr->ai_addrlen); + if (res == SOCKET_ERROR) { + closesocket(sock_); + sock_ = INVALID_SOCKET; + continue; + } + + break; + } + + if (sock_ == INVALID_SOCKET) { + WSACleanup(); + return errors::Internal("Unable to connect to server"); + } + + LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established"; + + return Status::OK(); +} + +Status PlainClient::Disconnect() { + int res = shutdown(sock_, SD_SEND); + closesocket(sock_); + WSACleanup(); + + if (res == SOCKET_ERROR) + return errors::Internal("Shutdown failed with error: ", WSAGetLastError()); + else + return Status::OK(); +} + +bool PlainClient::IsConnected() { return sock_ != INVALID_SOCKET; } + +int PlainClient::GetSocketDescriptor() { return sock_; } + +Status PlainClient::ReadData(uint8_t *buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = recv(sock_, (char *)buf, length - received, 0); + + if (res < 0) + return errors::Internal("Error occurred while reading from socket: ", + res); + + if (res == 0) return errors::Internal("Server closed connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status PlainClient::WriteData(const uint8_t *buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = send(sock_, (char *)buf, length - sent, 0); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", + res); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..ceb479b0846574a35d86002ebb9c3e8e1d3687ac --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +static int PasswordCb(char *buf, int size, int rwflag, void *password) { + strncpy(buf, (char *)(password), size); + buf[size - 1] = '\0'; + return (strlen(buf)); +} + +SslWrapper::SslWrapper(std::shared_ptr client, string certfile, + string keyfile, string cert_password, bool big_endian) + : Client(big_endian), + client_(client), + certfile_(std::move(certfile)), + keyfile_(std::move(keyfile)), + cert_password_(std::move(cert_password)), + ctx_(nullptr), + ssl_(nullptr) {} + +SslWrapper::~SslWrapper() { + if (IsConnected()) { + Status status = Disconnect(); + if (!status.ok()) LOG(WARNING) << status.ToString(); + } + + if (ctx_ != nullptr) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + + if (ssl_ != nullptr) { + SSL_free(ssl_); + ssl_ = nullptr; + } +} + +Status SslWrapper::InitSslContext() { + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + + ctx_ = SSL_CTX_new(SSLv23_method()); + if (ctx_ == NULL) return errors::Internal("Couldn't create SSL context"); + + SSL_CTX_set_default_passwd_cb(ctx_, PasswordCb); + SSL_CTX_set_default_passwd_cb_userdata(ctx_, (void *)cert_password_.c_str()); + + if (SSL_CTX_use_certificate_chain_file(ctx_, certfile_.c_str()) != 1) + return errors::Internal("Couldn't load cetificate chain (file '", certfile_, + "')"); + + string private_key_file = keyfile_.empty() ? certfile_ : keyfile_; + if (SSL_CTX_use_PrivateKey_file(ctx_, private_key_file.c_str(), + SSL_FILETYPE_PEM) != 1) + return errors::Internal("Couldn't load private key (file '", + private_key_file, "')"); + + return Status::OK(); +} + +Status SslWrapper::Connect() { + if (ctx_ == NULL) { + TF_RETURN_IF_ERROR(InitSslContext()); + } + + ssl_ = SSL_new(ctx_); + if (ssl_ == NULL) + return errors::Internal("Failed to establish SSL connection"); + + TF_RETURN_IF_ERROR(client_->Connect()); + + SSL_set_fd(ssl_, client_->GetSocketDescriptor()); + if (SSL_connect(ssl_) != 1) + return errors::Internal("Failed to establish SSL connection"); + + LOG(INFO) << "SSL connection established"; + + return Status::OK(); +} + +Status SslWrapper::Disconnect() { + SSL_free(ssl_); + ssl_ = nullptr; + + LOG(INFO) << "SSL connection closed"; + + return client_->Disconnect(); +} + +bool SslWrapper::IsConnected() { return client_->IsConnected(); } + +int SslWrapper::GetSocketDescriptor() { return client_->GetSocketDescriptor(); } + +Status SslWrapper::ReadData(uint8_t *buf, const int32_t length) { + int received = 0; + + while (received < length) { + int res = SSL_read(ssl_, buf, length - received); + + if (res < 0) + return errors::Internal("Error occurred while reading from SSL socket: ", + res); + + if (res == 0) return errors::Internal("Server closed SSL connection"); + + received += res; + buf += res; + } + + return Status::OK(); +} + +Status SslWrapper::WriteData(const uint8_t *buf, const int32_t length) { + int sent = 0; + + while (sent < length) { + int res = SSL_write(ssl_, buf, length - sent); + + if (res < 0) + return errors::Internal("Error occurred while writing into socket: ", + res); + + sent += res; + buf += res; + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..0406644bbaab3de816540ce85e84b489ea9fff12 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ + +#include "tensorflow/contrib/ignite/kernels/ignite_client.h" + +#include + +namespace tensorflow { + +class SslWrapper : public Client { + public: + SslWrapper(std::shared_ptr client, string certfile, string keyfile, + string cert_password, bool big_endian); + ~SslWrapper(); + + Status Connect() override; + Status Disconnect() override; + bool IsConnected() override; + int GetSocketDescriptor() override; + Status ReadData(uint8_t* buf, const int32_t length) override; + Status WriteData(const uint8_t* buf, const int32_t length) override; + + private: + Status InitSslContext(); + + std::shared_ptr client_; + string certfile_; + string keyfile_; + string cert_password_; + SSL_CTX* ctx_; + SSL* ssl_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ diff --git a/tensorflow/contrib/ignite/ops/dataset_ops.cc b/tensorflow/contrib/ignite/ops/dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d6fbe00e6296941b4ce77d1238a79099bb9a5aa --- /dev/null +++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("IgniteDataset") + .Input("cache_name: string") + .Input("host: string") + .Input("port: int32") + .Input("local: bool") + .Input("part: int32") + .Input("page_size: int32") + .Input("schema: int32") + .Input("permutation: int32") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +IgniteDataset that allows to get data from Apache Ignite. + +Apache Ignite is a memory-centric distributed database, caching, and processing +platform for transactional, analytical, and streaming workloads, delivering +in-memory speeds at petabyte scale. This contrib package contains an +integration between Apache Ignite and TensorFlow. The integration is based on +tf.data from TensorFlow side and Binary Client Protocol from Apache Ignite side. +It allows to use Apache Ignite as a datasource for neural network training, +inference and all other computations supported by TensorFlow. Ignite Dataset +is based on Apache Ignite Binary Client Protocol. + +cache_name: Ignite Cache Name. +host: Ignite Thin Client Host. +port: Ignite Thin Client Port. +local: Local flag that defines that data should be fetched from local host only. +part: Partition data should be fetched from. +page_size: Page size for Ignite Thin Client. +schema: Internal structure that defines schema of cache objects. +permutation: Internal structure that defines permutation of cache objects. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..288d4853207176b215cd8a0cdcbfb2de5791ecb8 --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -0,0 +1,772 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ignite Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import socket +import ssl +import struct + +from tensorflow.contrib.ignite.python.ops import gen_dataset_ops +from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class Readable(object): + """Readable abstract class that exposes methods to do reading-related + + operations. + """ + + @abc.abstractmethod + def __init__(self): + pass + + def read_byte(self): + """Reads and returnes byte.""" + return self._read("b", 1) + + def read_short(self): + """Reads and returns short (2 bytes, little-endian).""" + return self._read("h", 2) + + def read_int(self): + """Reads and returns int (4 bytes, little-endian).""" + return self._read("i", 4) + + def read_long(self): + """Reads and returns long (8 bytes, little-endian).""" + return self._read("q", 8) + + def skip(self, length): + """Skips the specified number of bytes.""" + self.read_data(length) + + @abc.abstractmethod + def read_data(self, length): + """Reads the specified number of bytes and returns them as a buffer.""" + return None + + def _read(self, data_type, length): + """Reads, unpacks and returns specified type (little-endian).""" + data_buffer = self.read_data(length) + return struct.unpack("<" + data_type, data_buffer)[0] + + +class DataBuffer(Readable): + """DataBuffer class that exposes methods to read data from a byte buffer.""" + + def __init__(self, data_buffer): + """Constructs a new instance based on the specified byte buffer. + + Args: + data_buffer: Buffer to be read. + """ + Readable.__init__(self) + self.buffer = data_buffer + self.ptr = 0 + + def read_data(self, length): + """Reads the specified number of bytes and returns them as a buffer.""" + data_buffer = self.buffer[self.ptr:][:length] + self.ptr += length + return data_buffer + + +class TcpClient(Readable): + """TcpClient class that exposes methods to read data from a socket.""" + + def __init__(self, host, port, certfile=None, keyfile=None, password=None): + """Constructs a new instance based on the specified host and port. + + Args: + host: Host to be connected. + port: Port to be connected. + certfile: File in PEM format containing the certificate as well as any + number of CA certificates needed to establish the certificate's + authenticity. + keyfile: File containing the private key (otherwise the private key will + be taken from certfile as well). + password: Password to be used if the private key is encrypted and a + password is necessary. + + Raises: + ValueError: If the wrong combination of arguments is provided. + """ + Readable.__init__(self) + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + if certfile is not None: + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.load_cert_chain(certfile, keyfile, password) + self.sock = context.wrap_socket(self.sock) + else: + if keyfile is not None: + raise ValueError("SSL is disabled, keyfile must not be specified " + "(to enable SSL specify certfile)") + if password is not None: + raise ValueError("SSL is disabled, password must not be specified " + "(to enable SSL specify certfile)") + + self.host = host + self.port = port + + def __enter__(self): + """Connects to host and port specified in the constructor.""" + self.sock.connect((self.host, self.port)) + return self + + def __exit__(self, t, v, traceback): + """Disconnects the socket.""" + self.sock.close() + + def write_byte(self, v): + """Writes the specified byte.""" + self._write(v, "b") + + def write_short(self, v): + """Writes the specified short (2 bytes, little-endian).""" + self._write(v, "h") + + def write_int(self, v): + """Writes the specified short (4 bytes, little-endian).""" + self._write(v, "i") + + def write_long(self, v): + """Writes the specified int (8 bytes, little-endian).""" + self._write(v, "q") + + def write_string(self, v): + """Writes the specified string.""" + self.sock.sendall(v.encode("UTF-8")) + + def read_data(self, length): + """Reads the specified number of bytes and returns them as a buffer.""" + data_buffer = None + rem = length + while rem > 0: + buf = self.sock.recv(rem) + rem = rem - len(buf) + if data_buffer is None: + data_buffer = buf + else: + data_buffer += buf + return data_buffer + + def _write(self, value, data_type): + """Packs and writes data using the specified type (little-endian).""" + data_buffer = struct.pack("<" + data_type, value) + self.sock.sendall(data_buffer) + + +class BinaryType(object): + """BinaryType class that encapsulated type id, type name and fields.""" + + def __init__(self, type_id, type_name, fields): + """Constructs a new instance of BinaryType.""" + self.type_id = type_id + self.type_name = type_name + self.fields = fields + + +class BinaryField(object): + """BinaryField class that encapsulated field name, type id and field id.""" + + def __init__(self, field_name, type_id, field_id): + """Constructs a new instance of BinaryField.""" + self.field_name = field_name + self.type_id = type_id + self.field_id = field_id + + +# Binary types defined in Apache Ignite Thin client and supported by +# TensorFlow on Apache Ignite, see +# https://apacheignite.readme.io/v2.6/docs/binary-client-protocol. +# True means that type is a vector, False means type is scalar. +types = { + 1: (dtypes.uint8, False), + 2: (dtypes.int16, False), + 3: (dtypes.int32, False), + 4: (dtypes.int64, False), + 5: (dtypes.float32, False), + 6: (dtypes.float64, False), + 7: (dtypes.uint16, False), + 8: (dtypes.bool, False), + 9: (dtypes.string, False), + 12: (dtypes.uint8, True), + 13: (dtypes.int16, True), + 14: (dtypes.int32, True), + 15: (dtypes.int64, True), + 16: (dtypes.float32, True), + 17: (dtypes.float64, True), + 18: (dtypes.uint16, True), + 19: (dtypes.bool, True), + 20: (dtypes.string, True) +} + + +class TypeTreeNode(object): + """TypeTreeNode class exposes methods to format object tree structure + + data. + """ + + def __init__(self, name, type_id, fields=None, permutation=None): + """Constructs a new instance of TypeTreeNode. + + Args: + name: Name of the object tree node. + type_id: Type id of the object tree node. + fields: List of fields (children of the object tree node). + permutation: Permutation that should be applied to order object children. + """ + self.name = name + self.type_id = type_id + self.fields = fields + self.permutation = permutation + + def to_output_classes(self): + """Formats the tree object as required by `Dataset.output_classes`.""" + if self.fields is None: + return ops.Tensor + output_classes = {} + for field in self.fields: + output_classes[field.name] = field.to_output_classes() + return output_classes + + def to_output_shapes(self): + """Formats the tree object as required by `Dataset.output_shapes`.""" + if self.fields is None: + if self.type_id in types: + object_type = types[self.type_id] + is_array = object_type[1] + if is_array: + return tensor_shape.TensorShape([None]) + return tensor_shape.TensorShape([]) + raise ValueError("Unsupported type [type_id=%d]" % self.type_id) + output_shapes = {} + for field in self.fields: + output_shapes[field.name] = field.to_output_shapes() + return output_shapes + + def to_output_types(self): + """Formats the tree object as required by `Dataset.output_types`.""" + if self.fields is None: + if self.type_id in types: + object_type = types[self.type_id] + return object_type[0] + raise ValueError("Unsupported type [type_id=%d]" % self.type_id) + else: + output_types = {} + for field in self.fields: + output_types[field.name] = field.to_output_types() + return output_types + + def to_flat(self): + """Returns a list of node types.""" + return self.to_flat_rec([]) + + def to_permutation(self): + """Returns a permutation that should be applied to order object leaves.""" + correct_order_dict = {} + self.traversal_rec(correct_order_dict, 0) + object_order = [] + self.traversal_permutation_rec(object_order) + return [correct_order_dict[o] for o in object_order] + + def to_flat_rec(self, flat): + """Formats a list of leaf node types in pre-order.""" + if self.fields is None: + flat.append(self.type_id) + else: + for field in self.fields: + field.to_flat_rec(flat) + return flat + + def traversal_permutation_rec(self, permutation): + """Collects nodes in accordance with permutation.""" + if self.fields is None: + permutation.append(self) + else: + for idx in self.permutation: + field = self.fields[idx] + field.traversal_permutation_rec(permutation) + + def traversal_rec(self, d, i): + """Collects nodes in pre-order traversal.""" + if self.fields is None: + d[self] = i + i += 1 + else: + for field in self.fields: + i = field.traversal_rec(d, i) + return i + + +class IgniteClient(TcpClient): + """IgniteClient enables working with Apache Ignite using a thin client. + + This client works with assumption that all object in the cache + have the same structure (homogeneous objects) and the cache contains at + least one object. + """ + + def __init__(self, + host, + port, + username=None, + password=None, + certfile=None, + keyfile=None, + cert_password=None): + """Constructs a new instance of IgniteClient. + + Args: + host: Apache Ignite Thin client host to be connected. + port: Apache Ignite Thin client port to be connected. + username: Apache Ignite Thin Client authentication username. + password: Apache Ignite Thin Client authentication password. + certfile: File in PEM format containing the certificate as well as any + number of CA certificates needed to establish the certificate's + authenticity. + keyfile: File containing the private key (otherwise the private key will + be taken from certfile as well). + cert_password: Password to be used if the private key is encrypted and a + password is necessary. + """ + TcpClient.__init__(self, host, port, certfile, keyfile, cert_password) + self.username = username + self.password = password + + def handshake(self): + """Makes a handshake after connect and before any other calls.""" + msg_len = 8 + + if self.username is None: + msg_len += 1 + else: + msg_len += 5 + len(self.username) + + if self.password is None: + msg_len += 1 + else: + msg_len += 5 + len(self.password) + + self.write_int(msg_len) # Message length + self.write_byte(1) # Handshake operation + self.write_short(1) # Version (1.1.0) + self.write_short(1) + self.write_short(0) + self.write_byte(2) # Thin client + + if self.username is None: # Username + self.write_byte(101) + else: + self.write_byte(9) + self.write_int(len(self.username)) + self.write_string(self.username) + + if self.password is None: # Password + self.write_byte(101) + else: + self.write_byte(9) + self.write_int(len(self.password)) + self.write_string(self.password) + + self.read_int() # Result length + res = self.read_byte() + + if res != 1: + serv_ver_major = self.read_short() + serv_ver_minor = self.read_short() + serv_ver_patch = self.read_short() + err_msg = self._parse_string() + if err_msg is None: + raise RuntimeError( + "Handshake Error [result=%d, version=%d.%d.%d]" % + (res, serv_ver_major, serv_ver_minor, serv_ver_patch)) + else: + raise RuntimeError( + "Handshake Error [result=%d, version=%d.%d.%d, message='%s']" % + (res, serv_ver_major, serv_ver_minor, serv_ver_patch, err_msg)) + + def get_cache_type(self, cache_name): + """Collects type information about objects stored in the specified cache.""" + cache_name_hash = self._java_hash_code(cache_name) + self.write_int(25) # Message length + self.write_short(2000) # Operation code + self.write_long(0) # Request ID + self.write_int(cache_name_hash) # Cache name + self.write_byte(0) # Flags + self.write_byte(101) # Filter (NULL) + self.write_int(1) # Cursor page size + self.write_int(-1) # Partition to query + self.write_byte(0) # Local flag + + result_length = self.read_int() + self.read_long() # Request id + status = self.read_int() + + if status != 0: + err_msg = self._parse_string() + if err_msg is None: + raise RuntimeError("Scan Query Error [status=%s]" % status) + else: + raise RuntimeError( + "Scan Query Error [status=%s, message='%s']" % (status, err_msg)) + + self.read_long() # Cursor id + row_count = self.read_int() + + if row_count == 0: + raise RuntimeError("Scan Query returned empty result, so it's " + "impossible to derive the cache type") + + payload = DataBuffer(self.read_data(result_length - 25)) + + self.read_byte() # Next page + + res = TypeTreeNode("root", 0, [ + self._collect_types("key", payload), + self._collect_types("val", payload) + ], [0, 1]) + + return res + + def _java_hash_code(self, s): + """Computes hash code of the specified string using Java code.""" + h = 0 + for c in s: + h = (31 * h + ord(c)) & 0xFFFFFFFF + return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000 + + def _collect_types(self, field_name, data): + """Extracts type information from the specified object.""" + type_id = data.read_byte() + + # Byte scalar. + if type_id == 1: + data.skip(1) + return TypeTreeNode(field_name, type_id) + + # Short scalar. + if type_id == 2: + data.skip(2) + return TypeTreeNode(field_name, type_id) + + # Integer scalar. + if type_id == 3: + data.skip(4) + return TypeTreeNode(field_name, type_id) + + # Long scalar. + if type_id == 4: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Float scalar. + if type_id == 5: + data.skip(4) + return TypeTreeNode(field_name, type_id) + + # Double scalar. + if type_id == 6: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Char scalar. + if type_id == 7: + data.skip(2) + return TypeTreeNode(field_name, type_id) + + # Bool scalar. + if type_id == 8: + data.skip(1) + return TypeTreeNode(field_name, type_id) + + # String scalar. + if type_id == 9: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # UUID scalar. + if type_id == 10: + data.skip(16) + return TypeTreeNode(field_name, type_id) + + # Date scalar. + if type_id == 11: + data.skip(8) + return TypeTreeNode(field_name, type_id) + + # Byte array. + if type_id == 12: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # Short array. + if type_id == 13: + length = data.read_int() + data.skip(length * 2) + return TypeTreeNode(field_name, type_id) + + # Integer array. + if type_id == 14: + length = data.read_int() + data.skip(length * 4) + return TypeTreeNode(field_name, type_id) + + # Long array. + if type_id == 15: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Float array. + if type_id == 16: + length = data.read_int() + data.skip(length * 4) + return TypeTreeNode(field_name, type_id) + + # Double array. + if type_id == 17: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Char array. + if type_id == 18: + length = data.read_int() + data.skip(length * 2) + return TypeTreeNode(field_name, type_id) + + # Bool array. + if type_id == 19: + length = data.read_int() + data.skip(length) + return TypeTreeNode(field_name, type_id) + + # String array. + if type_id == 20: + length = data.read_int() + for _ in range(length): + header = data.read_byte() + if header == 9: + str_length = data.read_int() + data.skip(str_length) + elif header == 101: + pass + else: + raise RuntimeError( + "Unknown binary type when expected string [type_id=%d]" % header) + return TypeTreeNode(field_name, type_id) + + # UUID array. + if type_id == 21: + length = data.read_int() + data.skip(length * 16) # TODO(dmitrievanthony): support NULL values. + return TypeTreeNode(field_name, type_id) + + # Date array. + if type_id == 22: + length = data.read_int() + data.skip(length * 8) + return TypeTreeNode(field_name, type_id) + + # Wrapped Binary Object. + if type_id == 27: + length = data.read_int() + inner_data = data.read_data(length) + data.read_int() # Offset + return self._collect_types(field_name, DataBuffer(inner_data)) + + # Complex Object. + if type_id == 103: + data.read_byte() # Object version + data.read_short() # Object flags + obj_type_id = data.read_int() + data.read_int() # Object hash code + obj_length = data.read_int() + data.read_int() # Object schema id + obj_schema_offset = data.read_int() + + obj_type = self._get_type(obj_type_id) + children = [] + + for obj_field in obj_type.fields: + child = self._collect_types(obj_field.field_name, data) + children.append(child) + + children_sorted = sorted(children, key=lambda child: child.name) + permutation = [children_sorted.index(child) for child in children] + children = children_sorted + + data.skip(obj_length - obj_schema_offset) + + return TypeTreeNode(field_name, type_id, children, permutation) + + raise RuntimeError("Unknown binary type [type_id=%d]" % type_id) + + def _get_type(self, type_id): + """Queries Apache Ignite information about type by type id.""" + self.write_int(14) # Message length + self.write_short(3002) # Operation code + self.write_long(0) # Request ID + self.write_int(type_id) # Type ID + + self.read_int() # Result length + self.read_long() # Request id + status = self.read_int() + + if status != 0: + err_msg = self._parse_string() + if err_msg is None: + raise RuntimeError("Get Binary Type Error [status=%d, message='%s']" % + (status, err_msg)) + else: + raise RuntimeError("Get Binary Type Error [status=%d]" % status) + + binary_type_exists = self.read_byte() + + if binary_type_exists == 0: + raise RuntimeError("Binary type not found [type_id=%d] " % type_id) + + binary_type_id = self.read_int() + binary_type_name = self._parse_string() + self._parse_string() # Affinity field name + + fields = [] + for _ in range(self.read_int()): + field_name = self._parse_string() + field_type_id = self.read_int() + field_id = self.read_int() + + field = BinaryField(field_name, field_type_id, field_id) + fields.append(field) + + is_enum = self.read_byte() + if is_enum == 1: + raise RuntimeError("Enum fields are not supported yet") + + schema_cnt = self.read_int() + for _ in range(schema_cnt): + self.read_int() # Schema id + field_cnt = self.read_int() + self.skip(field_cnt * 4) + + return BinaryType(binary_type_id, binary_type_name, fields) + + def _parse_string(self): + """Parses string.""" + header = self.read_byte() + if header == 9: + length = self.read_int() + return self.read_data(length).decode("utf-8") + if header == 101: + return None + raise RuntimeError( + "Unknown binary type when expected string [type_id=%d]" % header) + + +class IgniteDataset(dataset_ops.DatasetSource): + """Apache Ignite is a memory-centric distributed database, caching, and + + processing platform for transactional, analytical, and streaming workloads, + delivering in-memory speeds at petabyte scale. This contrib package + contains an integration between Apache Ignite and TensorFlow. The + integration is based on tf.data from TensorFlow side and Binary Client + Protocol from Apache Ignite side. It allows to use Apache Ignite as a + datasource for neural network training, inference and all other + computations supported by TensorFlow. Ignite Dataset is based on Apache + Ignite Binary Client Protocol. + """ + + def __init__(self, + cache_name, + host="localhost", + port=10800, + local=False, + part=-1, + page_size=100, + username=None, + password=None, + certfile=None, + keyfile=None, + cert_password=None): + """Create a IgniteDataset. + + Args: + cache_name: Cache name to be used as datasource. + host: Apache Ignite Thin Client host to be connected. + port: Apache Ignite Thin Client port to be connected. + local: Local flag that defines to query only local data. + part: Number of partitions to be queried. + page_size: Apache Ignite Thin Client page size. + username: Apache Ignite Thin Client authentication username. + password: Apache Ignite Thin Client authentication password. + certfile: File in PEM format containing the certificate as well as any + number of CA certificates needed to establish the certificate's + authenticity. + keyfile: File containing the private key (otherwise the private key will + be taken from certfile as well). + cert_password: Password to be used if the private key is encrypted and a + password is necessary. + """ + super(IgniteDataset, self).__init__() + + with IgniteClient(host, port, username, password, certfile, keyfile, + cert_password) as client: + client.handshake() + self.cache_type = client.get_cache_type(cache_name) + + self.cache_name = ops.convert_to_tensor( + cache_name, dtype=dtypes.string, name="cache_name") + self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host") + self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port") + self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local") + self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part") + self.page_size = ops.convert_to_tensor( + page_size, dtype=dtypes.int32, name="page_size") + self.schema = ops.convert_to_tensor( + self.cache_type.to_flat(), dtype=dtypes.int32, name="schema") + self.permutation = ops.convert_to_tensor( + self.cache_type.to_permutation(), + dtype=dtypes.int32, + name="permutation") + + def _as_variant_tensor(self): + return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, + self.local, self.part, self.page_size, + self.schema, self.permutation) + + @property + def output_classes(self): + return self.cache_type.to_output_classes() + + @property + def output_shapes(self): + return self.cache_type.to_output_shapes() + + @property + def output_types(self): + return self.cache_type.to_output_types() diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py similarity index 94% rename from tensorflow/contrib/data/python/ops/contrib_op_loader.py rename to tensorflow/contrib/ignite/python/ops/ignite_op_loader.py index 8f495a9dc9c82311435e71d2ac9ed35fd9aea794..c9af7386cf0a26ed1a950130aa36caa7fb831fd0 100644 --- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Python helper for loading contrib ops and kernels.""" +"""Python helper for loading Ignite ops and kernels.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh new file mode 100755 index 0000000000000000000000000000000000000000..f4607ce8adab38c27d040ad1118858d17b924a6a --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml & +sleep 5 # Wait Apache Ignite to be started + +./apache-ignite-fabric/bin/sqlline.sh \ +-u "jdbc:ignite:thin://127.0.0.1/" \ +--run=/data/sql/init.sql + +tail -f nohup.out diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml new file mode 100644 index 0000000000000000000000000000000000000000..d900174a8abb3987c380e4a1a193ed81295fb88c --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + 127.0.0.1 + + + + + + + + + diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ef29b5f14a4b2fea2400ec4d56a7ad2cf44cf2cb --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for IgniteDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.ignite import IgniteDataset +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class IgniteDatasetTest(test.TestCase): + """The Apache Ignite servers have to setup before the test and tear down + + after the test manually. The docker engine has to be installed. + + To setup Apache Ignite servers: + $ bash start_ignite.sh + + To tear down Apache Ignite servers: + $ bash stop_ignite.sh + """ + + def test_ignite_dataset_with_plain_client(self): + """Test Ignite Dataset with plain client. + + """ + self._clear_env() + ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300) + self._check_dataset(ds) + + def _clear_env(self): + """Clears environment variables used by Ignite Dataset. + + """ + if "IGNITE_DATASET_USERNAME" in os.environ: + del os.environ["IGNITE_DATASET_USERNAME"] + if "IGNITE_DATASET_PASSWORD" in os.environ: + del os.environ["IGNITE_DATASET_PASSWORD"] + if "IGNITE_DATASET_CERTFILE" in os.environ: + del os.environ["IGNITE_DATASET_CERTFILE"] + if "IGNITE_DATASET_CERT_PASSWORD" in os.environ: + del os.environ["IGNITE_DATASET_CERT_PASSWORD"] + + def _check_dataset(self, dataset): + """Checks that dataset provides correct data.""" + self.assertEqual(dtypes.int64, dataset.output_types["key"]) + self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) + self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) + + it = dataset.make_one_shot_iterator() + ne = it.get_next() + + with session.Session() as sess: + rows = [sess.run(ne), sess.run(ne), sess.run(ne)] + with self.assertRaises(errors.OutOfRangeError): + sess.run(ne) + + self.assertEqual({"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}, rows[0]) + self.assertEqual({"key": 2, "val": {"NAME": b"TEST2", "VAL": 43}}, rows[1]) + self.assertEqual({"key": 3, "val": {"NAME": b"TEST3", "VAL": 44}}, rows[2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql new file mode 100644 index 0000000000000000000000000000000000000000..5a192aef17e22544e853cb78b4eb235beded42fe --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql @@ -0,0 +1,20 @@ +-- Copyright 2018 The TensorFlow Authors. All Rights Reserved. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- ============================================================================== + +CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG); + +INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42); +INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43); +INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44); diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh new file mode 100755 index 0000000000000000000000000000000000000000..a67bd44f2fb0d654ba07f022a5070c68df8e2ede --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +IGNITE_VERSION=2.6.0 +SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )" + +# Start Apache Ignite with plain client listener. +docker run -itd --name ignite-plain -p 42300:10800 \ +-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh old mode 100644 new mode 100755 similarity index 76% rename from tensorflow/contrib/linalg/python/__init__.py rename to tensorflow/contrib/ignite/python/tests/stop_ignite.sh index c5ca3a623fb15c44d04f2222708353d2934490e4..8f03dbd1ede61f548d3de9d9738f97667e75df3c --- a/tensorflow/contrib/linalg/python/__init__.py +++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""ops module.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +docker rm -f ignite-plain +docker rm -f ignite-ssl +docker rm -f ignite-ssl-auth diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 370a8caf6a71cc09629a5e75fd9151ae3f0f3b6d..788bf04b28aaad5d532258e0946fd03111384c69 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -156,6 +156,7 @@ namespace functor { TF_CALL_uint8(DECLARE_FUNCTOR); TF_CALL_int32(DECLARE_FUNCTOR); TF_CALL_int64(DECLARE_FUNCTOR); +TF_CALL_half(DECLARE_FUNCTOR); TF_CALL_float(DECLARE_FUNCTOR); TF_CALL_double(DECLARE_FUNCTOR); @@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR); TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index 6b63eed1303accc330293b3a44cdb9def7881666..7fac774d07fa8e07a0730ad018ba70e2c73a9cc5 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -71,14 +71,7 @@ class ProjectiveGenerator { (transform[3] * output_x + transform[4] * output_y + transform[5]) / projection; - // TODO(ringwalt): Add a fill value input. -#if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000) - // On CUDA versions previous to 8.0, only __shared__ variables - // could be declared as static in the device code. const T fill_value = T(0); -#else - static const T fill_value = T(0); -#endif switch (interpolation_) { case INTERPOLATION_NEAREST: // Switch the order of x and y again for indexing into the image. diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc index 8743a5ff724a5000ed0376045340f9ceaaccbfd2..36b9a236a6ea48e3b27dac956c93aecee321e2b7 100644 --- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc @@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice; template class FillProjectiveTransform; template class FillProjectiveTransform; template class FillProjectiveTransform; +template class FillProjectiveTransform; template class FillProjectiveTransform; template class FillProjectiveTransform; diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 376c0751eebb4906920ed338647630798d509113..4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -272,6 +272,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase): with self.cached_session(): self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval()) + def test_transform_data_types(self): + for dtype in _DTYPES: + image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual( + value.eval(), + np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py index 9ed017592afdcf6608833458eba192f616c9249d..f44edaa14c08a013e06ad7595b70dceb04950a04 100644 --- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class InputPipelineOpsTest(test.TestCase): def testObtainNext(self): - with self.test_session(): + with self.cached_session(): var = state_ops.variable_op([], dtypes.int64) state_ops.assign(var, -1).op.run() c = constant_op.constant(["a", "b"]) @@ -45,7 +45,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNext(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list) session.run([variables.global_variables_initializer()]) self.assertEqual(b"a", session.run(elem)) @@ -65,7 +65,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNextLimitEpochs(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list, num_epochs=1) session.run([ variables.local_variables_initializer(), @@ -75,7 +75,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNextLimitEpochsThree(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list, num_epochs=3) session.run([ variables.local_variables_initializer(), diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py index 621911876fc502ece76b08eb6c28697b3c12c863..08ebcdb544645d3585a1af25c86c6182a1589dcb 100644 --- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -54,7 +54,7 @@ class KafkaDatasetTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from topic 0. sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) for i in range(5): diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index a1624614d1ab1be31463c5cdc0b4cfb653165a0c..7129f09e8b42e48a9c768fd4a66cde3d4da9d31d 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -17,15 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import from tensorflow.contrib.kafka.python.ops import gen_dataset_ops -from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -class KafkaDataset(Dataset): +class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. """ diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index 72507539f813d14064bc58f03b6db4781abc9438..4d5cc24ce0926486011814bc78a47da9db478bf1 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLogitsShape(self): """An error is raised when logits have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2,)) labels = constant_op.constant([0, 1]) with self.assertRaises(ValueError): @@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLabelsShape(self): """An error is raised when labels have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(1, 1, 2)) with self.assertRaises(ValueError): @@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidWeightsShape(self): """An error is raised when weights have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(2,)) weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1)) @@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLabelsDtype(self): """An error is raised when labels have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], dtype=dtypes.float32) with self.assertRaises(ValueError): @@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNoneWeightRaisesValueError(self): """An error is raised when weights are None.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0]) with self.assertRaises(ValueError): @@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInconsistentLabelsAndWeightsShapesSameRank(self): """Error raised when weights and labels have same ranks, different sizes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1)) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) weights = constant_op.constant([1.1, 2.0], shape=(2, 1)) @@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInconsistentLabelsAndWeightsShapesDifferentRank(self): """Error raised when weights and labels have different ranks and sizes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(2, 1)) weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,)) @@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testOutOfRangeLabels(self): """An error is raised when labels are not in [0, num_classes).""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) labels = constant_op.constant([1, 0, 4]) @@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testZeroLossInt32Labels(self): """Loss is 0 if true class logits sufficiently higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32) @@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testZeroLossInt64Labels(self): """Loss is 0 if true class logits sufficiently higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0], [-0.5, 0.8, -1.0]]) labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64) @@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): ] for batch_size, num_classes in logits_shapes: - with self.test_session(): + with self.cached_session(): logits = array_ops.placeholder( dtypes.float32, shape=(batch_size, num_classes)) labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) @@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testCorrectPredictionsSomeClassesInsideMargin(self): """Loss is > 0 even if true class logits are higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0], [1.5, 1.8, -1.0]]) labels = constant_op.constant([0, 2, 1]) @@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictions(self): """Loss is >0 when an incorrect class has higher logits than true class.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0], [0.5, -1.8, 2.0]]) labels = constant_op.constant([1, 0, 2]) @@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictionsColumnLabels(self): """Same as above but labels is a rank-2 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictionsZeroWeights(self): """Loss is 0 when all weights are missing even if predictions are wrong.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWithPythonScalarWeights(self): """Weighted loss is correctly computed when weights is a python scalar.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWithScalarTensorWeights(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWith1DTensorWeightsColumnLabels(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]]) labels = constant_op.constant([1, 0, 2, 1]) diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 2ff4d41d75fe59fb765a83e1b6a5b3eaad9d9163..bad0a596a78f0d1b9833f670ed256350f24f450d 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): def testInvalidInputShape(self): x = constant_op.constant([[2.0, 1.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 10) with self.assertRaisesWithPredicateMatch( dense_kernel_mapper.InvalidShapeError, @@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0], [4.0, -2.0, -1.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 10, 1.0) mapped_x1 = rffm.map(x1) mapped_x2 = rffm.map(x2) @@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): def testSameOmegaReused(self): x = constant_op.constant([[2.0, 1.0, 0.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 100) mapped_x = rffm.map(x) mapped_x_copy = rffm.map(x) @@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): y = constant_op.constant([[1.0, -1.0, 2.0]]) stddev = 3.0 - with self.test_session(): + with self.cached_session(): # The mapped dimension is fairly small, so the kernel approximation is # very rough. rffm1 = RandomFourierFeatureMapper(3, 100, stddev) @@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): y = constant_op.constant([[1.0, -1.0, 2.0]]) stddev = 3.0 - with self.test_session(): + with self.cached_session(): # The mapped dimension is fairly small, so the kernel approximation is # very rough. rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0) @@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): normalized_points = [nn.l2_normalize(point, dim=1) for point in points] total_absolute_error = 0.0 - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0) # Cache mappings so that they are not computed multiple times. cached_mappings = dict((point, rffm.map(point)) diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py index 7289b45c50fa92455b4c317b8a039ca414fa585e..bf89922318b9b9a569e4bd1d71fe6283810cadda 100644 --- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -64,7 +64,7 @@ class KinesisDatasetTest(test.TestCase): init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from shard 0 of stream 1. sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1}) for i in range(10): @@ -108,7 +108,7 @@ class KinesisDatasetTest(test.TestCase): get_next = iterator.get_next() data = list() - with self.test_session() as sess: + with self.cached_session() as sess: # Basic test: read from shard 0 of stream 2. sess.run( init_op, feed_dict={ diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index ca2df95ba4f20ec5fa58ff13530096e6e065f4fe..75806dbbeb1819bb0a6965bbc384e02df9895210 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -17,15 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops -from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -class KinesisDataset(Dataset): +class KinesisDataset(dataset_ops.DatasetSource): """A Kinesis Dataset that consumes the message. Kinesis is a managed service provided by AWS for data streaming. diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py index 28ddaa69a14776e0c157c2e68105ee9e17bc3cbb..155d06a08e6195f5032884ddd986374de58c66cb 100644 --- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py +++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py @@ -45,7 +45,7 @@ class SparseCrossOpTest(test.TestCase): 'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2', 'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_dense(self): @@ -66,7 +66,7 @@ class SparseCrossOpTest(test.TestCase): 'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2', 'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_integer_mixed_string_sparse(self): @@ -80,7 +80,7 @@ class SparseCrossOpTest(test.TestCase): '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_integer_mixed_string_dense(self): @@ -99,7 +99,7 @@ class SparseCrossOpTest(test.TestCase): '55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2', '999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_sparse_cross_dense(self): @@ -117,7 +117,7 @@ class SparseCrossOpTest(test.TestCase): 'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2', 'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_integer_sparse_input(self): @@ -133,7 +133,7 @@ class SparseCrossOpTest(test.TestCase): '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_permutation_3x3x3(self): @@ -176,7 +176,7 @@ class SparseCrossOpTest(test.TestCase): 'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2', 'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_permutation_3x1x2(self): @@ -196,7 +196,7 @@ class SparseCrossOpTest(test.TestCase): 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1', 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2' ]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_large_batch(self): @@ -229,7 +229,7 @@ class SparseCrossOpTest(test.TestCase): ]) expected_out = self._sparse_tensor(col_out) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_one_column_empty(self): @@ -242,7 +242,7 @@ class SparseCrossOpTest(test.TestCase): self._sparse_tensor([], 1), self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) ]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_empty(sess.run(op)) def test_some_columns_empty(self): @@ -261,7 +261,7 @@ class SparseCrossOpTest(test.TestCase): 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1', 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2' ]], 2) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_all_columns_empty(self): @@ -273,7 +273,7 @@ class SparseCrossOpTest(test.TestCase): self._sparse_tensor([]), self._sparse_tensor([]), self._sparse_tensor([]) ]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_empty(sess.run(op)) def test_hashed_output_zero_bucket(self): @@ -288,7 +288,7 @@ class SparseCrossOpTest(test.TestCase): hashed_output=True) # Check actual hashed output to prevent unintentional hashing changes. expected_out = self._sparse_tensor([[3735511728867393167]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_hashed_output_zero_bucket_v2(self): @@ -304,7 +304,7 @@ class SparseCrossOpTest(test.TestCase): hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY) # Check actual hashed output to prevent unintentional hashing changes. expected_out = self._sparse_tensor([[1971693436396284976]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) # TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed. @@ -321,7 +321,7 @@ class SparseCrossOpTest(test.TestCase): num_buckets=100) # Check actual hashed output to prevent unintentional hashing changes. expected_out = self._sparse_tensor([[74]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_hashed_output_v2(self): @@ -338,7 +338,7 @@ class SparseCrossOpTest(test.TestCase): hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY) # Check actual hashed output to prevent unintentional hashing changes. expected_out = self._sparse_tensor([[83]]) - with self.test_session() as sess: + with self.cached_session() as sess: self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_hashed_output_v1_has_collision(self): @@ -384,7 +384,7 @@ class SparseCrossOpTest(test.TestCase): ], hashed_output=True, num_buckets=1000) - with self.test_session() as sess: + with self.cached_session() as sess: out = sess.run(op) self.assertEqual(6, len(out.values)) self.assertAllEqual([[0, i] for i in range(6)], out.indices) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 85af9de4e4f6b4a90dd51830a482db357541f466..3b7ae72e9c460ee7a38f72b03e1c1ad48e335f57 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -2360,7 +2360,7 @@ class BatchNormTest(test.TestCase): batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, dtype=dtypes.float32) - is_training = variables_lib.Variable(True) + is_training = variables_lib.VariableV1(True) output = _layers.batch_norm( images, decay=0.1, @@ -2507,7 +2507,7 @@ class BatchNormTest(test.TestCase): batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, dtype=dtypes.float32) - is_training = variables_lib.Variable(True) + is_training = variables_lib.VariableV1(True) output = _layers.batch_norm( images, decay=0.1, diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 69d927e1b3001d14dd1af2f890b07c1a57ab2cfc..2fdcd849b026d52ed4aff724838f6c71e3a315d0 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -21,8 +21,6 @@ from __future__ import print_function import six from tensorflow.contrib import framework as contrib_framework -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -433,12 +431,11 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers): if (grad is not None and (var in gradient_multipliers or var.name in gradient_multipliers)): key = var if var in gradient_multipliers else var.name - multiplier = constant_op.constant( - gradient_multipliers[key], dtype=dtypes.float32) + multiplier = gradient_multipliers[key] if isinstance(grad, ops.IndexedSlices): grad_values = grad.values * multiplier grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape) else: - grad *= multiplier + grad *= math_ops.cast(multiplier, grad.dtype) multiplied_grads_and_vars.append((grad, var)) return multiplied_grads_and_vars diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 29dede2a495d4364bc5da161135243b7bff7a7f3..b4d1239e768cb2cc19eb058cf36ccc7267a86a42 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -250,6 +250,42 @@ class OptimizersTest(test.TestCase): self.assertAlmostEqual(var_value, 6.5, 4) self.assertEqual(global_step_value, 1) + def testGradientMultiplyInt32Tensor(self): + with self.cached_session() as session: + x, var, loss, global_step = _setup_model() + v = array_ops.placeholder(dtypes.float32, []) + train = optimizers_lib.optimize_loss( + loss, + global_step, + learning_rate=0.1, + optimizer="SGD", + gradient_multipliers={var: v}) + variables.global_variables_initializer().run() + session.run(train, feed_dict={x: 5, v: 7.}) + var_value, global_step_value = session.run([var, global_step]) + # var(0) = 10, x = 5, var(0)/dx = 5, + # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx + self.assertAlmostEqual(var_value, 6.5, 4) + self.assertEqual(global_step_value, 1) + + def testGradientMultiplyInt64Tensor(self): + with self.cached_session() as session: + x, var, loss, global_step = _setup_model() + v = array_ops.placeholder(dtypes.float64, []) + train = optimizers_lib.optimize_loss( + loss, + global_step, + learning_rate=0.1, + optimizer="SGD", + gradient_multipliers={var: v}) + variables.global_variables_initializer().run() + session.run(train, feed_dict={x: 5, v: 7.}) + var_value, global_step_value = session.run([var, global_step]) + # var(0) = 10, x = 5, var(0)/dx = 5, + # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx + self.assertAlmostEqual(var_value, 6.5, 4) + self.assertEqual(global_step_value, 1) + def testIgnoreVariablesWithNoGradients(self): _, _, loss, global_step = _setup_model() diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 69bb6be81453f5f5487f25547f017dc5f87c2f2c..8a6b4f68a8b33d497ddb16614a7e3cdf32f2c422 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -396,7 +396,7 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn): def _mean_squared_loss(logits, target): # To prevent broadcasting inside "-". if len(target.get_shape()) == 1: - target = array_ops.expand_dims(target, dim=[1]) + target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) return math_ops.square(logits - math_ops.to_float(target)) @@ -405,7 +405,7 @@ def _mean_squared_loss(logits, target): def _log_loss_with_two_classes(logits, target): # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target. if len(target.get_shape()) == 1: - target = array_ops.expand_dims(target, dim=[1]) + target = array_ops.expand_dims(target, axis=1) loss_vec = nn.sigmoid_cross_entropy_with_logits( labels=math_ops.to_float(target), logits=logits) return loss_vec diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index ded93d4a7fb473c0c5df446ea89c5ab7784e9f3c..c6f79e00d5a5a584b0c5f8201a2576f02106a5b4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, axis=(1,)) + labels = array_ops.expand_dims(labels, axis=1) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, axis=(1,)) + logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = math_ops.square(logits - math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) @@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, axis=(1,)) + labels = array_ops.expand_dims(labels, axis=1) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, axis=(1,)) + logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, name=name) @@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None): # TODO(ptucker): This will break for dynamic shapes. # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, axis=(1,)) + labels = array_ops.expand_dims(labels, axis=1) loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index d5c02124ac6a626de5e158b4dbe388a063ce4692..a160cb54a3954c73d35f0d3fb5b437c5f9f08984 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -162,9 +162,9 @@ class GraphActionsTest(test.TestCase): Tuple of 3 `Tensor` objects, 2 input and 1 output. """ variables_lib.create_global_step() - in0 = variables.Variable(1.0) + in0 = variables.VariableV1(1.0) in1 = variables_lib.local_variable(2.0) - fake_table = variables.Variable( + fake_table = variables.VariableV1( 3.0, trainable=False, collections=['fake_tables'], @@ -234,7 +234,7 @@ class GraphActionsTest(test.TestCase): self.assertTrue(test_ops.resource_initialized_op(handle).eval()) def test_infer_different_default_graph(self): - with self.test_session(): + with self.cached_session(): self._assert_ckpt(self._output_dir, False) with ops.Graph().as_default(): in0, in1, out = self._build_inference_graph() @@ -312,8 +312,8 @@ class GraphActionsTest(test.TestCase): def test_evaluate_ready_for_local_init(self): with ops.Graph().as_default() as g, self.session(g): variables_lib.create_global_step() - v = variables.Variable(1.0) - variables.Variable( + v = variables.VariableV1(1.0) + variables.VariableV1( v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False) ready_for_local_init_op = variables.report_uninitialized_variables( variables.global_variables()) @@ -456,9 +456,9 @@ class GraphActionsTrainTest(test.TestCase): Tuple of 3 `Tensor` objects, 2 input and 1 output. """ variables_lib.create_global_step() - in0 = variables.Variable(1.0) + in0 = variables.VariableV1(1.0) in1 = variables_lib.local_variable(2.0) - fake_table = variables.Variable( + fake_table = variables.VariableV1( 3.0, trainable=False, collections=['fake_tables'], diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 83e48a36e71caae7474f6bb8a33379ab75f7abcf..d4a7169bb632f682a19268529fee56b01ab5fbcb 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -247,7 +247,7 @@ class MonitorsTest(test.TestCase): def test_logging_trainable(self): with ops.Graph().as_default() as g, self.session(g): - var = variables.Variable(constant_op.constant(42.0), name='foo') + var = variables.VariableV1(constant_op.constant(42.0), name='foo') var.initializer.run() cof = constant_op.constant(1.0) loss = math_ops.subtract( @@ -261,7 +261,7 @@ class MonitorsTest(test.TestCase): with ops.Graph().as_default() as g, self.session(g): log_dir = 'log/dir' summary_writer = testing.FakeSummaryWriter(log_dir, g) - var = variables.Variable(0.0) + var = variables.VariableV1(0.0) var.initializer.run() tensor = state_ops.assign_add(var, 1.0) summary_op = summary.scalar('my_summary', tensor) @@ -526,8 +526,8 @@ class MonitorsTest(test.TestCase): monitor0 = learn.monitors.GraphDump() monitor1 = learn.monitors.GraphDump() with ops.Graph().as_default() as g, self.session(g): - const_var = variables.Variable(42.0, name='my_const') - counter_var = variables.Variable(0.0, name='my_counter') + const_var = variables.VariableV1(42.0, name='my_const') + counter_var = variables.VariableV1(0.0, name='my_counter') assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add') variables.global_variables_initializer().run() @@ -569,7 +569,7 @@ class MonitorsTest(test.TestCase): monitor = learn.monitors.CaptureVariable( var_name='my_assign_add:0', every_n=8, first_n=2) with ops.Graph().as_default() as g, self.session(g): - var = variables.Variable(0.0, name='my_var') + var = variables.VariableV1(0.0, name='my_var') var.initializer.run() state_ops.assign_add(var, 1.0, name='my_assign_add') self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 2f33a2b74d44ef4684b2e86d54db7a0363e402d5..0e5ea6b9f714c8feacdc156617622be5e0d079f0 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -47,7 +47,7 @@ from tensorflow.python.training import adam class Seq2SeqTest(test.TestCase): def testRNNDecoder(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 @@ -65,7 +65,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testBasicRNNSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 @@ -81,7 +81,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testTiedRNNSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 @@ -98,7 +98,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testEmbeddingRNNDecoder(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 @@ -124,7 +124,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].h.shape) def testEmbeddingRNNSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): enc_inp = [ @@ -228,7 +228,7 @@ class Seq2SeqTest(test.TestCase): self.assertAllClose(res1, res3) def testEmbeddingTiedRNNSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): enc_inp = [ @@ -316,7 +316,7 @@ class Seq2SeqTest(test.TestCase): self.assertAllClose(res1, res3) def testAttentionDecoder1(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): cell_fn = lambda: rnn_cell.GRUCell(2) @@ -341,7 +341,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testAttentionDecoder2(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): cell_fn = lambda: rnn_cell.GRUCell(2) @@ -367,7 +367,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testDynamicAttentionDecoder1(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): cell_fn = lambda: rnn_cell.GRUCell(2) @@ -391,7 +391,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testDynamicAttentionDecoder2(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): cell_fn = lambda: rnn_cell.GRUCell(2) @@ -416,7 +416,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testAttentionDecoderStateIsTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda @@ -448,7 +448,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0][1].h.shape) def testDynamicAttentionDecoderStateIsTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda @@ -479,7 +479,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0][1].h.shape) def testEmbeddingAttentionDecoder(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 @@ -513,7 +513,7 @@ class Seq2SeqTest(test.TestCase): self.assertEqual((2, 2), res[0].shape) def testEmbeddingAttentionSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): enc_inp = [ @@ -622,7 +622,7 @@ class Seq2SeqTest(test.TestCase): # self.assertAllClose(res1, res3) def testOne2ManyRNNSeq2Seq(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): enc_inp = [ @@ -712,7 +712,7 @@ class Seq2SeqTest(test.TestCase): self.assertAllClose(res1, res3) def testSequenceLoss(self): - with self.test_session() as sess: + with self.cached_session() as sess: logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)] targets = [ constant_op.constant( @@ -748,7 +748,7 @@ class Seq2SeqTest(test.TestCase): self.assertAllClose(9.656628, res) def testSequenceLossByExample(self): - with self.test_session() as sess: + with self.cached_session() as sess: output_classes = 5 logits = [ constant_op.constant( @@ -778,7 +778,7 @@ class Seq2SeqTest(test.TestCase): # classes = 10 # buckets = [(4, 4), (8, 8)] - # with self.test_session(): + # with self.cached_session(): # # Here comes a sample Seq2Seq model using GRU cells. # def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss): # """Example sequence-to-sequence model that uses GRU cells.""" @@ -839,7 +839,7 @@ class Seq2SeqTest(test.TestCase): random.seed(111) np.random.seed(111) - with self.test_session() as sess: + with self.cached_session() as sess: # We use sampled softmax so we keep output projection separate. w = variable_scope.get_variable("proj_w", [24, classes]) w_t = array_ops.transpose(w) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD deleted file mode 100644 index 78b7970069fec2d67f816b39d8fa4c58021cef85..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/linalg/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Description: -# Contains classes that provide access to common method of a [batch] matrix, -# without the need to instantiate the matrix. -# This allows for exploitation of structure, as well as a generic interface -# suitable for iterative solvers. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -py_library( - name = "linalg_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:util", - "//tensorflow/python/ops/linalg", - "@six_archive//:six", - ], -) - -cuda_py_test( - name = "linear_operator_addition_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_addition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py deleted file mode 100644 index cbe4c03e4d1b4b3c0b773d78bc505e9cb1161ab3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/linalg/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Linear algebra libraries. - -See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg) -guide. - -@@LinearOperator -@@LinearOperatorBlockDiag -@@LinearOperatorCirculant -@@LinearOperatorCirculant2D -@@LinearOperatorCirculant3D -@@LinearOperatorDiag -@@LinearOperatorIdentity -@@LinearOperatorScaledIdentity -@@LinearOperatorFullMatrix -@@LinearOperatorKronecker -@@LinearOperatorLowerTriangular -@@LinearOperatorLowRankUpdate -@@LinearOperatorComposition -@@add_operators - -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member - -from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.python.ops.linalg.linear_operator import * -from tensorflow.python.ops.linalg.linear_operator_block_diag import * -from tensorflow.python.ops.linalg.linear_operator_circulant import * -from tensorflow.python.ops.linalg.linear_operator_composition import * -from tensorflow.python.ops.linalg.linear_operator_diag import * -from tensorflow.python.ops.linalg.linear_operator_full_matrix import * -from tensorflow.python.ops.linalg.linear_operator_identity import * -from tensorflow.python.ops.linalg.linear_operator_kronecker import * -from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * -from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * - -# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member - -from tensorflow.python.util.all_util import remove_undocumented - -remove_undocumented(__name__) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py deleted file mode 100644 index 6a72df6dfd8d8c35211bab42b240b83d77160a02..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py +++ /dev/null @@ -1,412 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.linalg.python.ops import linear_operator_addition -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops.linalg import linalg as linalg_lib -from tensorflow.python.platform import test - -linalg = linalg_lib -random_seed.set_random_seed(23) -rng = np.random.RandomState(0) - -add_operators = linear_operator_addition.add_operators - - -# pylint: disable=unused-argument -class _BadAdder(linear_operator_addition._Adder): - """Adder that will fail if used.""" - - def can_add(self, op1, op2): - raise AssertionError("BadAdder.can_add called!") - - def _add(self, op1, op2, operator_name, hints): - raise AssertionError("This line should not be reached") - - -# pylint: enable=unused-argument - - -class LinearOperatorAdditionCorrectnessTest(test.TestCase): - """Tests correctness of addition with combinations of a few Adders. - - Tests here are done with the _DEFAULT_ADDITION_TIERS, which means - add_operators should reduce all operators resulting in one single operator. - - This shows that we are able to correctly combine adders using the tiered - system. All Adders should be tested separately, and there is no need to test - every Adder within this class. - """ - - def test_one_operator_is_returned_unchanged(self): - op_a = linalg.LinearOperatorDiag([1., 1.]) - op_sum = add_operators([op_a]) - self.assertEqual(1, len(op_sum)) - self.assertTrue(op_sum[0] is op_a) - - def test_at_least_one_operators_required(self): - with self.assertRaisesRegexp(ValueError, "must contain at least one"): - add_operators([]) - - def test_attempting_to_add_numbers_raises(self): - with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"): - add_operators([1, 2]) - - def test_two_diag_operators(self): - op_a = linalg.LinearOperatorDiag( - [1., 1.], is_positive_definite=True, name="A") - op_b = linalg.LinearOperatorDiag( - [2., 2.], is_positive_definite=True, name="B") - with self.test_session(): - op_sum = add_operators([op_a, op_b]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag)) - self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval()) - # Adding positive definite operators produces positive def. - self.assertTrue(op.is_positive_definite) - # Real diagonal ==> self-adjoint. - self.assertTrue(op.is_self_adjoint) - # Positive definite ==> non-singular - self.assertTrue(op.is_non_singular) - # Enforce particular name for this simple case - self.assertEqual("Add/B__A/", op.name) - - def test_three_diag_operators(self): - op1 = linalg.LinearOperatorDiag( - [1., 1.], is_positive_definite=True, name="op1") - op2 = linalg.LinearOperatorDiag( - [2., 2.], is_positive_definite=True, name="op2") - op3 = linalg.LinearOperatorDiag( - [3., 3.], is_positive_definite=True, name="op3") - with self.test_session(): - op_sum = add_operators([op1, op2, op3]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag)) - self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) - # Adding positive definite operators produces positive def. - self.assertTrue(op.is_positive_definite) - # Real diagonal ==> self-adjoint. - self.assertTrue(op.is_self_adjoint) - # Positive definite ==> non-singular - self.assertTrue(op.is_non_singular) - - def test_diag_tril_diag(self): - op1 = linalg.LinearOperatorDiag( - [1., 1.], is_non_singular=True, name="diag_a") - op2 = linalg.LinearOperatorLowerTriangular( - [[2., 0.], [0., 2.]], - is_self_adjoint=True, - is_non_singular=True, - name="tril") - op3 = linalg.LinearOperatorDiag( - [3., 3.], is_non_singular=True, name="diag_b") - with self.test_session(): - op_sum = add_operators([op1, op2, op3]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular)) - self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) - - # The diag operators will be self-adjoint (because real and diagonal). - # The TriL operator has the self-adjoint hint set. - self.assertTrue(op.is_self_adjoint) - - # Even though op1/2/3 are non-singular, this does not imply op is. - # Since no custom hint was provided, we default to None (unknown). - self.assertEqual(None, op.is_non_singular) - - def test_matrix_diag_tril_diag_uses_custom_name(self): - op0 = linalg.LinearOperatorFullMatrix( - [[-1., -1.], [-1., -1.]], name="matrix") - op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") - op2 = linalg.LinearOperatorLowerTriangular( - [[2., 0.], [1.5, 2.]], name="tril") - op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") - with self.test_session(): - op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix)) - self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval()) - self.assertEqual("my_operator", op.name) - - def test_incompatible_domain_dimensions_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(2, 4)) - with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"): - add_operators([op1, op2]) - - def test_incompatible_range_dimensions_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(3, 3)) - with self.assertRaisesRegexp(ValueError, "must.*same range dimension"): - add_operators([op1, op2]) - - def test_non_broadcastable_batch_shape_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3)) - with self.assertRaisesRegexp(ValueError, "Incompatible shapes"): - add_operators([op1, op2]) - - -class LinearOperatorOrderOfAdditionTest(test.TestCase): - """Test that the order of addition is done as specified by tiers.""" - - def test_tier_0_additions_done_in_tier_0(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - diag3 = linalg.LinearOperatorDiag([1.]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - ] - # Should not raise since all were added in tier 0, and tier 1 (with the - # _BadAdder) was never reached. - op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag)) - - def test_tier_1_additions_done_by_tier_1(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [linear_operator_addition._AddAndReturnTriL()], - [_BadAdder()], - ] - # Should not raise since all were added by tier 1, and the - # _BadAdder) was never reached. - op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) - - def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnTriL()], - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - ] - # Tier 0 could convert to TriL, and this converted everything to TriL, - # including the Diags. - # Tier 1 was never used. - # Tier 2 was never used (therefore, _BadAdder didn't raise). - op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) - - def test_cannot_add_everything_so_return_more_than_one_operator(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([2.]) - tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - ] - # Tier 0 (the only tier) can only convert to Diag, so it combines the two - # diags, but the TriL is unchanged. - # Result should contain two operators, one Diag, one TriL. - op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers) - self.assertEqual(2, len(op_sum)) - found_diag = False - found_tril = False - with self.test_session(): - for op in op_sum: - if isinstance(op, linalg.LinearOperatorDiag): - found_diag = True - self.assertAllClose([[3.]], op.to_dense().eval()) - if isinstance(op, linalg.LinearOperatorLowerTriangular): - found_tril = True - self.assertAllClose([[5.]], op.to_dense().eval()) - self.assertTrue(found_diag and found_tril) - - def test_intermediate_tier_is_not_skipped(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - [linear_operator_addition._AddAndReturnTriL()], - ] - # tril cannot be added in tier 0, and the intermediate tier 1 with the - # BadAdder will catch it and raise. - with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"): - add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - - -class AddAndReturnScaledIdentityTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnScaledIdentity() - - def test_identity_plus_identity(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2) - id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.test_session(): - self.assertAllClose(2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_identity_plus_scaled_identity(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.test_session(): - self.assertAllClose(3.2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_scaled_identity_plus_scaled_identity(self): - id1 = linalg.LinearOperatorScaledIdentity( - num_rows=2, multiplier=[2.2, 2.2, 2.2]) - id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.test_session(): - self.assertAllClose(1.2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnDiagTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnDiag() - - def test_identity_plus_identity_returns_diag(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2) - id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag)) - - with self.test_session(): - self.assertAllClose(2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_diag_plus_diag(self): - diag1 = rng.rand(2, 3, 4) - diag2 = rng.rand(4) - op1 = linalg.LinearOperatorDiag(diag1) - op2 = linalg.LinearOperatorDiag(diag2) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(op1, op2)) - operator = self._adder.add(op1, op2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag)) - - with self.test_session(): - self.assertAllClose( - linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnTriLTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnTriL() - - def test_diag_plus_tril(self): - diag = linalg.LinearOperatorDiag([1., 2.]) - tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(diag, diag)) - self.assertTrue(self._adder.can_add(diag, tril)) - operator = self._adder.add(diag, tril, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular)) - - with self.test_session(): - self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnMatrixTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnMatrix() - - def test_diag_plus_diag(self): - diag1 = linalg.LinearOperatorDiag([1., 2.]) - diag2 = linalg.LinearOperatorDiag([-1., 3.]) - hints = linear_operator_addition._Hints( - is_positive_definite=False, is_non_singular=False) - - self.assertTrue(self._adder.can_add(diag1, diag2)) - operator = self._adder.add(diag1, diag2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix)) - - with self.test_session(): - self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval()) - self.assertFalse(operator.is_positive_definite) - self.assertFalse(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py deleted file mode 100644 index 86130a2c077ce14a7539b281ec809029bc05e071..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Add one or more `LinearOperators` efficiently.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc - -import six - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops.linalg import linear_operator -from tensorflow.python.ops.linalg import linear_operator_diag -from tensorflow.python.ops.linalg import linear_operator_full_matrix -from tensorflow.python.ops.linalg import linear_operator_identity -from tensorflow.python.ops.linalg import linear_operator_lower_triangular - -__all__ = [] - - -def add_operators(operators, - operator_name=None, - addition_tiers=None, - name=None): - """Efficiently add one or more linear operators. - - Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of - operators `[B1, B2,...]` such that - - ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` - - The operators `Bk` result by adding some of the `Ak`, as allowed by - `addition_tiers`. - - Example of efficient adding of diagonal operators. - - ```python - A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") - A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") - - # Use two tiers, the first contains an Adder that returns Diag. Since both - # A1 and A2 are Diag, they can use this Adder. The second tier will not be - # used. - addition_tiers = [ - [_AddAndReturnDiag()], - [_AddAndReturnMatrix()]] - B_list = add_operators([A1, A2], addition_tiers=addition_tiers) - - len(B_list) - ==> 1 - - B_list[0].__class__.__name__ - ==> 'LinearOperatorDiag' - - B_list[0].to_dense() - ==> [[3., 0.], - [0., 3.]] - - B_list[0].name - ==> 'Add/A1__A2/' - ``` - - Args: - operators: Iterable of `LinearOperator` objects with same `dtype`, domain - and range dimensions, and broadcastable batch shapes. - operator_name: String name for returned `LinearOperator`. Defaults to - concatenation of "Add/A__B/" that indicates the order of addition steps. - addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` - is a list of `Adder` objects. This function attempts to do all additions - in tier `i` before trying tier `i + 1`. - name: A name for this `Op`. Defaults to `add_operators`. - - Returns: - Subclass of `LinearOperator`. Class and order of addition may change as new - (and better) addition strategies emerge. - - Raises: - ValueError: If `operators` argument is empty. - ValueError: If shapes are incompatible. - """ - # Default setting - if addition_tiers is None: - addition_tiers = _DEFAULT_ADDITION_TIERS - - # Argument checking. - check_ops.assert_proper_iterable(operators) - operators = list(reversed(operators)) - if len(operators) < 1: - raise ValueError( - "Argument 'operators' must contain at least one operator. " - "Found: %s" % operators) - if not all( - isinstance(op, linear_operator.LinearOperator) for op in operators): - raise TypeError( - "Argument 'operators' must contain only LinearOperator instances. " - "Found: %s" % operators) - _static_check_for_same_dimensions(operators) - _static_check_for_broadcastable_batch_shape(operators) - - graph_parents = [] - for operator in operators: - graph_parents.extend(operator.graph_parents) - - with ops.name_scope(name or "add_operators", values=graph_parents): - - # Additions done in one of the tiers. Try tier 0, 1,... - ops_to_try_at_next_tier = list(operators) - for tier in addition_tiers: - ops_to_try_at_this_tier = ops_to_try_at_next_tier - ops_to_try_at_next_tier = [] - while ops_to_try_at_this_tier: - op1 = ops_to_try_at_this_tier.pop() - op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) - if op2 is not None: - # Will try to add the result of this again at this same tier. - new_operator = adder.add(op1, op2, operator_name) - ops_to_try_at_this_tier.append(new_operator) - else: - ops_to_try_at_next_tier.append(op1) - - return ops_to_try_at_next_tier - - -def _pop_a_match_at_tier(op1, operator_list, tier): - # Search from the back of list to the front in order to create nice default - # order of operations. - for i in range(1, len(operator_list) + 1): - op2 = operator_list[-i] - for adder in tier: - if adder.can_add(op1, op2): - return operator_list.pop(-i), adder - return None, None - - -def _infer_hints_allowing_override(op1, op2, hints): - """Infer hints from op1 and op2. hints argument is an override. - - Args: - op1: LinearOperator - op2: LinearOperator - hints: _Hints object holding "is_X" boolean hints to use for returned - operator. - If some hint is None, try to set using op1 and op2. If the - hint is provided, ignore op1 and op2 hints. This allows an override - of previous hints, but does not allow forbidden hints (e.g. you still - cannot say a real diagonal operator is not self-adjoint. - - Returns: - _Hints object. - """ - hints = hints or _Hints() - # If A, B are self-adjoint, then so is A + B. - if hints.is_self_adjoint is None: - is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint - else: - is_self_adjoint = hints.is_self_adjoint - - # If A, B are positive definite, then so is A + B. - if hints.is_positive_definite is None: - is_positive_definite = op1.is_positive_definite and op2.is_positive_definite - else: - is_positive_definite = hints.is_positive_definite - - # A positive definite operator is always non-singular. - if is_positive_definite and hints.is_positive_definite is None: - is_non_singular = True - else: - is_non_singular = hints.is_non_singular - - return _Hints( - is_non_singular=is_non_singular, - is_self_adjoint=is_self_adjoint, - is_positive_definite=is_positive_definite) - - -def _static_check_for_same_dimensions(operators): - """ValueError if operators determined to have different dimensions.""" - if len(operators) < 2: - return - - domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators - if op.domain_dimension.value is not None] - if len(set(value for name, value in domain_dimensions)) > 1: - raise ValueError("Operators must have the same domain dimension. Found: %s" - % domain_dimensions) - - range_dimensions = [(op.name, op.range_dimension.value) for op in operators - if op.range_dimension.value is not None] - if len(set(value for name, value in range_dimensions)) > 1: - raise ValueError("Operators must have the same range dimension. Found: %s" % - range_dimensions) - - -def _static_check_for_broadcastable_batch_shape(operators): - """ValueError if operators determined to have non-broadcastable shapes.""" - if len(operators) < 2: - return - - # This will fail if they cannot be broadcast together. - batch_shape = operators[0].batch_shape - for op in operators[1:]: - batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) - - -class _Hints(object): - """Holds 'is_X' flags that every LinearOperator is initialized with.""" - - def __init__(self, - is_non_singular=None, - is_positive_definite=None, - is_self_adjoint=None): - self.is_non_singular = is_non_singular - self.is_positive_definite = is_positive_definite - self.is_self_adjoint = is_self_adjoint - - -################################################################################ -# Classes to add two linear operators. -################################################################################ - - -@six.add_metaclass(abc.ABCMeta) -class _Adder(object): - """Abstract base class to add two operators. - - Each `Adder` acts independently, adding everything it can, paying no attention - as to whether another `Adder` could have done the addition more efficiently. - """ - - @property - def name(self): - return self.__class__.__name__ - - @abc.abstractmethod - def can_add(self, op1, op2): - """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" - pass - - @abc.abstractmethod - def _add(self, op1, op2, operator_name, hints): - # Derived classes can assume op1 and op2 have been validated, e.g. they have - # the same dtype, and their domain/range dimensions match. - pass - - def add(self, op1, op2, operator_name, hints=None): - """Return new `LinearOperator` acting like `op1 + op2`. - - Args: - op1: `LinearOperator` - op2: `LinearOperator`, with `shape` and `dtype` such that adding to - `op1` is allowed. - operator_name: `String` name to give to returned `LinearOperator` - hints: `_Hints` object. Returned `LinearOperator` will be created with - these hints. - - Returns: - `LinearOperator` - """ - updated_hints = _infer_hints_allowing_override(op1, op2, hints) - - if operator_name is None: - operator_name = "Add/" + op1.name + "__" + op2.name + "/" - - values = op1.graph_parents + op2.graph_parents - scope_name = self.name - if scope_name.startswith("_"): - scope_name = scope_name[1:] - with ops.name_scope(scope_name, values=values): - return self._add(op1, op2, operator_name, updated_hints) - - -class _AddAndReturnScaledIdentity(_Adder): - """Handles additions resulting in an Identity family member. - - The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family - is closed under addition. This `Adder` respects that, and returns an Identity - """ - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_IDENTITY_FAMILY) - - def _add(self, op1, op2, operator_name, hints): - # Will build a LinearOperatorScaledIdentity. - - if _type(op1) == _SCALED_IDENTITY: - multiplier_1 = op1.multiplier - else: - multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) - - if _type(op2) == _SCALED_IDENTITY: - multiplier_2 = op2.multiplier - else: - multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) - - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=op1.range_dimension_tensor(), - multiplier=multiplier_1 + multiplier_2, - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnDiag(_Adder): - """Handles additions resulting in a Diag operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE) - - def _add(self, op1, op2, operator_name, hints): - return linear_operator_diag.LinearOperatorDiag( - diag=op1.diag_part() + op2.diag_part(), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnTriL(_Adder): - """Handles additions resulting in a TriL operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE.union({_TRIL})) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnMatrix(_Adder): - """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" - - def can_add(self, op1, op2): # pylint: disable=unused-argument - return isinstance(op1, linear_operator.LinearOperator) and isinstance( - op2, linear_operator.LinearOperator) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - return linear_operator_full_matrix.LinearOperatorFullMatrix( - matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -################################################################################ -# Constants designating types of LinearOperators -################################################################################ - -# Type name constants for LinearOperator classes. -_IDENTITY = "identity" -_SCALED_IDENTITY = "scaled_identity" -_DIAG = "diag" -_TRIL = "tril" -_MATRIX = "matrix" - -# Groups of operators. -_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY} -_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY} -# operators with an efficient .add_to_tensor() method. -_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE - - -def _type(operator): - """Returns the type name constant (e.g. _TRIL) for operator.""" - if isinstance(operator, linear_operator_diag.LinearOperatorDiag): - return _DIAG - if isinstance(operator, - linear_operator_lower_triangular.LinearOperatorLowerTriangular): - return _TRIL - if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): - return _MATRIX - if isinstance(operator, linear_operator_identity.LinearOperatorIdentity): - return _IDENTITY - if isinstance(operator, - linear_operator_identity.LinearOperatorScaledIdentity): - return _SCALED_IDENTITY - raise TypeError("Operator type unknown: %s" % operator) - - -################################################################################ -# Addition tiers: -# We attempt to use Adders in tier K before K+1. -# -# Organize tiers to -# (i) reduce O(..) complexity of forming final operator, and -# (ii) produce the "most efficient" final operator. -# Dev notes: -# * Results of addition at tier K will be added at tier K or higher. -# * Tiers may change, and we warn the user that it may change. -################################################################################ - -# Note that the final tier, _AddAndReturnMatrix, will convert everything to a -# dense matrix. So it is sometimes very inefficient. -_DEFAULT_ADDITION_TIERS = [ - [_AddAndReturnScaledIdentity()], - [_AddAndReturnDiag()], - [_AddAndReturnTriL()], - [_AddAndReturnMatrix()], -] diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 1d2db1cec8f28c1d7b991ec9639086eb81dc32b9..8466dc36d13e223aed4f1dfe8e39a6f91c99fa55 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -125,7 +125,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): ], example_ids=[str(i) for i in range(num_examples)]) - weights = variables_lib.Variable( + weights = variables_lib.VariableV1( array_ops.zeros([dim], dtype=dtypes.float32)) variables_dict = dict( sparse_features_weights=[weights], @@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender, partitioned=False): +def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. partitioner = None @@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False): partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, axis=0) with variable_scope.variable_scope( - name_or_scope='variables', + name_or_scope=('variables/shard_{}'.format(num_shards) + if num_shards else 'variables'), partitioner=partitioner): - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + age_weights = variable_scope.get_variable( + name='age', + initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32)) + gender_weights = variable_scope.get_variable( + name='gender', + initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -183,7 +184,7 @@ def make_dense_examples_and_variables_dicts(dense_features_values, weights, dense_tensors.append(dense_tensor) # Add variables of shape [feature_column_dimension]. dense_weights.append( - variables_lib.Variable( + variables_lib.VariableV1( array_ops.zeros( [dense_tensor.get_shape().as_list()[1]], dtype=dtypes.float32))) @@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1, partitioned=True) + variables = make_variable_dict(1, 1, num_shards, partitioned=True) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -322,6 +323,68 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSomePartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [0], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + # Explicitly make age a [1]-shaped Variable (which cannot be + # partitioned), while making gender a PartitionedVariable. + age_weights = variables_lib.VariableV1( + array_ops.zeros([1], dtype=dtypes.float32)) + with variable_scope.variable_scope( + name_or_scope=('variables/shard_{}'.format(num_shards) + if num_shards else 'variables'), + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)): + gender_weights = variable_scope.get_variable( + name='gender', + initializer=array_ops.zeros([2], dtype=dtypes.float32)) + variables = dict( + sparse_features_weights=[age_weights, gender_weights], + dense_features_weights=[]) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.593014 is the unregularized_loss at that optimum. + self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.593014, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): dim = 20 num_examples = 1000 @@ -463,7 +526,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=0, symmetric_l1_regularization=0, @@ -521,7 +584,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): with self._single_threaded_test_session(): # Only use examples 0 and 2 examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -561,7 +624,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -598,7 +661,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(3, 1) + variables = make_variable_dict(3, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -639,7 +702,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -679,7 +742,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): for num_shards in _SHARD_NUMBERS: with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) + variables = make_variable_dict(1, 1, num_shards) options = dict( symmetric_l2_regularization=1, symmetric_l1_regularization=0, @@ -738,7 +801,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): labels=[1.0, 0.0]) # Replace with a variable of size 1 instead of 2. variables['dense_features_weights'] = [ - variables_lib.Variable(array_ops.zeros( + variables_lib.VariableV1(array_ops.zeros( [1], dtype=dtypes.float32)) ] options = dict( diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 14f59a3f64e5eb91c9754497620b137aae51ad81..b5099a0bf6d4c5425c6f8316e74ae835cf016592 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,6 +22,7 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -151,7 +152,8 @@ class SdcaModel(object): default_value=[0.0, 0.0, 0.0, 0.0], # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe # empty_key (that will never collide with actual payloads). - empty_key=[0, 0]) + empty_key=[0, 0], + deleted_key=[1, 1]) summary.scalar('approximate_duality_gap', self.approximate_duality_gap()) summary.scalar('examples_seen', self._hashtable.size()) @@ -400,14 +402,16 @@ class SdcaModel(object): sparse_weights = [] sparse_indices = [] - # If we have partitioned variables, keep a few lists of Tensors around - # that we need for the assign_add after the op call to - # gen_sdca_ops.sdca_optimizer(). - num_partitions_by_var = [] - p_assignments_by_var = [] - gather_ids_by_var = [] - for w, i in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_feature_indices): + # If we have partitioned variables, keep a few dictionaries of Tensors + # around that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). These are keyed because we may have a + # mix of partitioned and un-partitioned variables. + num_partitions_by_var = {} + p_assignments_by_var = {} + gather_ids_by_var = {} + for v_num, (w, i) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices)): # Append the sparse_indices (in full-variable space). sparse_idx = math_ops.cast( array_ops.unique(math_ops.cast(i, dtypes.int32))[0], @@ -456,10 +460,10 @@ class SdcaModel(object): gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, num_partitions) - # Append these to the lists for use in the later update. - num_partitions_by_var.append(num_partitions) - p_assignments_by_var.append(p_assignments) - gather_ids_by_var.append(gather_ids) + # Add these into the dictionaries for use in the later update. + num_partitions_by_var[v_num] = num_partitions + p_assignments_by_var[v_num] = p_assignments + gather_ids_by_var[v_num] = gather_ids # Gather the weights from each partition. partition_gathered_weights = [] @@ -483,24 +487,44 @@ class SdcaModel(object): sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access - esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( - sparse_example_indices, - sparse_feature_indices, - sparse_features_values, - self._convert_n_to_tensor(self._examples['dense_features']), - internal_convert_to_tensor(self._examples['example_weights']), - internal_convert_to_tensor(self._examples['example_labels']), - sparse_indices, - sparse_weights, - self._convert_n_to_tensor(self._slots[ - 'unshrinked_dense_features_weights']), - example_state_data, - loss_type=self._options['loss_type'], - l1=self._options['symmetric_l1_regularization'], - l2=self._symmetric_l2_regularization(), - num_loss_partitions=self._num_loss_partitions(), - num_inner_iterations=1, - adaptative=self._adaptive()) + if compat.forward_compatible(year=2018, month=10, day=30): + esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2( + sparse_example_indices, + sparse_feature_indices, + sparse_features_values, + self._convert_n_to_tensor(self._examples['dense_features']), + internal_convert_to_tensor(self._examples['example_weights']), + internal_convert_to_tensor(self._examples['example_labels']), + sparse_indices, + sparse_weights, + self._convert_n_to_tensor(self._slots[ + 'unshrinked_dense_features_weights']), + example_state_data, + loss_type=self._options['loss_type'], + l1=self._options['symmetric_l1_regularization'], + l2=self._symmetric_l2_regularization(), + num_loss_partitions=self._num_loss_partitions(), + num_inner_iterations=1, + adaptive=self._adaptive()) + else: + esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( + sparse_example_indices, + sparse_feature_indices, + sparse_features_values, + self._convert_n_to_tensor(self._examples['dense_features']), + internal_convert_to_tensor(self._examples['example_weights']), + internal_convert_to_tensor(self._examples['example_labels']), + sparse_indices, + sparse_weights, + self._convert_n_to_tensor(self._slots[ + 'unshrinked_dense_features_weights']), + example_state_data, + loss_type=self._options['loss_type'], + l1=self._options['symmetric_l1_regularization'], + l2=self._symmetric_l2_regularization(), + num_loss_partitions=self._num_loss_partitions(), + num_inner_iterations=1, + adaptative=self._adaptive()) # pylint: enable=protected-access with ops.control_dependencies([esu]): diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 5015fb0848107950dd27eb81431dd308f22858bc..44a869f7c2745c594b6a4ea69a2a9e6f1b4f780a 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -48,6 +48,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype, default_value, empty_key, + deleted_key, num_shards=1, checkpoint=True, name='ShardedMutableHashTable'): @@ -62,6 +63,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + deleted_key=deleted_key, checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 553b116a3b3d76423d4700691fb6912101bebca4..2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -33,6 +33,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = -1 empty_key = 0 + deleted_key = -1 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) table = ShardedMutableDenseHashTable( @@ -40,6 +41,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -56,6 +58,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): with self.cached_session(): default_val = [-0.1, 0.2] empty_key = [0, 1] + deleted_key = [1, 0] keys = constant_op.constant([[11, 12], [13, 14], [15, 16]], dtypes.int64) values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], @@ -65,6 +68,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.float32, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) @@ -81,6 +85,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): def testExportSharded(self): with self.cached_session(): empty_key = -2 + deleted_key = -3 default_val = -1 num_shards = 2 keys = constant_op.constant([10, 11, 12], dtypes.int64) @@ -90,6 +95,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase): dtypes.int64, default_val, empty_key, + deleted_key, num_shards=num_shards) self.assertAllEqual(0, table.size().eval()) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index f320b53d940d3f2b76975fb1302aaf344e785aca..787a85644c35c807df84f74cbce06f80fd0b004d 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -4,6 +4,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") exports_files(glob([ @@ -26,6 +27,14 @@ config_setting( }, ) +# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate. +# WARNING: This build flag is experimental and subject to change. +config_setting( + name = "with_tflite_flex", + define_values = {"with_tflite_flex": "true"}, + visibility = ["//visibility:public"], +) + cc_library( name = "schema_fbs_version", hdrs = ["version.h"], @@ -180,7 +189,12 @@ cc_library( "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/profiling:profiler", "//tensorflow/contrib/lite/schema:schema_fbs", - ], + ] + select({ + ":with_tflite_flex": [ + "//tensorflow/contrib/lite/delegates/flex:delegate", + ], + "//conditions:default": [], + }), ) cc_library( @@ -259,6 +273,7 @@ cc_test( "testdata/0_subgraphs.bin", "testdata/2_subgraphs.bin", "testdata/empty_model.bin", + "testdata/multi_add_flex.bin", "testdata/test_model.bin", "testdata/test_model_broken.bin", ], @@ -266,6 +281,26 @@ cc_test( ":framework", "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework with the flex library linked into the target. +tf_cc_test( + name = "model_flex_test", + size = "small", + srcs = ["model_flex_test.cc"], + data = [ + "testdata/multi_add_flex.bin", + ], + tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC. + deps = [ + ":framework", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index a676b705f143b393c7e5bfa9e40d23f9adb68dcc..a4b3d83efe09358cb8e7a5f673a96f28faa84d08 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -4,5 +4,5 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. -See the documentation: https://www.tensorflow.org/mobile/tflite/ -Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite) +See the documentation: https://www.tensorflow.org/lite/ +Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/) diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 5c705ea53b23732ae76820890dd31bd49a5daae6..f962a138f712d8988c24a97439ae0233ac0c1a31 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -212,7 +212,8 @@ def json_to_tflite(name, src, out): # This is the master list of generated examples that will be made into tests. A # function called make_XXX_tests() must also appear in generate_examples.py. -# Disable a test by commenting it out. If you do, add a link to a bug or issue. +# Disable a test by adding it to the blacklists specified in +# generated_test_models_failing(). def generated_test_models(): return [ "add", @@ -291,31 +292,63 @@ def generated_test_models(): "tile", "topk", "transpose", - #"transpose_conv", # disabled due to b/111213074 + "transpose_conv", "unpack", "where", + "zeros_like", ] +# List of models that fail generated tests for the conversion mode. +# If you have to disable a test, please add here with a link to the appropriate +# bug or issue. +def generated_test_models_failing(conversion_mode): + if not conversion_mode: + return [ + "transpose_conv", # disabled due to b/111213074 + ] + + if conversion_mode == "toco-flex": + # TODO(b/117328698): Fix and enable the known flex failures. + return [ + "lstm", + "split", + "unpack", + ] + + return [] + def generated_test_conversion_modes(): """Returns a list of conversion modes.""" # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050. - return ["toco-extended", ""] + return ["toco-flex", ""] def generated_test_models_all(): """Generates a list of all tests with the different converters. Returns: - List of tuples representing (conversion mode, name of test). + List of tuples representing: + (conversion mode, name of test, test tags, test args). """ conversion_modes = generated_test_conversion_modes() tests = generated_test_models() options = [] for conversion_mode in conversion_modes: + failing_tests = generated_test_models_failing(conversion_mode) for test in tests: + tags = [] + args = [] + if test in failing_tests: + tags.append("notap") + tags.append("manual") if conversion_mode: test += "_%s" % conversion_mode - options.append((conversion_mode, test)) + + # Flex conversion shouldn't suffer from the same conversion bugs + # listed for the default TFLite kernel backend. + if conversion_mode == "toco-flex": + args.append("--ignore_known_bugs=false") + options.append((conversion_mode, test, tags, args)) return options def gen_zip_test(name, test_name, conversion_mode, **kwargs): @@ -334,14 +367,7 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs): # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050. # if conversion_mode == "pb2lite": # toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite" - flags = "--ignore_toco_errors --run_with_extended" - kwargs["tags"].append("skip_already_failing") - kwargs["tags"].append("no_oss") - - # TODO(b/115504899): Re-enable asan, msan and tsan tests. - kwargs["tags"].append("noasan") - kwargs["tags"].append("nomsan") - kwargs["tags"].append("notsan") + flags = "--ignore_toco_errors --run_with_flex" gen_zipped_test_file( name = "zip_%s" % test_name, @@ -394,3 +420,42 @@ def gen_selected_ops(name, model): (tool, model, out, tflite_path[2:]), tools = [tool], ) + +def gen_full_model_test(conversion_modes, models, data, test_suite_tag): + """Generates Python test targets for testing TFLite models. + + Args: + conversion_modes: List of conversion modes to test the models on. + models: List of models to test. + data: List of BUILD targets linking the data. + test_suite_tag: Tag identifying the model test suite. + """ + options = [ + (conversion_mode, model) + for model in models + for conversion_mode in conversion_modes + ] + + for conversion_mode, model_name in options: + native.py_test( + name = "model_coverage_test_%s_%s" % (model_name, conversion_mode.lower()), + srcs = ["model_coverage_test.py"], + main = "model_coverage_test.py", + args = [ + "--model_name=%s" % model_name, + "--converter_mode=%s" % conversion_mode, + ], + data = data, + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_windows", + "notap", + "manual", + ] + [test_suite_tag], + deps = [ + "//tensorflow/contrib/lite/testing:model_coverage_lib", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/python:client_testlib", + ], + ) diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 5e97b777fc6cd90e2fe71ad203fecebe912c25d7..6117cbf9f15074766f90971dbf695eaa0b4678e3 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -118,6 +118,9 @@ typedef enum { kTfLiteBuiltinFloorDiv = 90, kTfLiteBuiltinReduceAny = 91, kTfLiteBuiltinSquare = 92, + kTfLiteBuiltinZerosLike = 93, + kTfLiteBuiltinFill = 94, + kTfLiteBuiltinFloorMod = 95, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h index fa43e6a0244b76e388a4f8b583c06572d9efa16b..1e65c3cee27798990eb9888e67306c6285925a1f 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -25,6 +25,9 @@ extern "C" { // TODO(aselle): Consider using "if this then that" for testing. +// IMPORTANT: All new members of structs must be added at the end to ensure +// backwards compatibility. + // Possible padding types (for convolutions) typedef enum { kTfLitePaddingUnknown = 0, @@ -71,11 +74,15 @@ typedef struct { } TfLitePoolParams; typedef struct { + // Parameters for DepthwiseConv version 1 or above. TfLitePadding padding; int stride_width; int stride_height; int depth_multiplier; TfLiteFusedActivation activation; + // Parameters for DepthwiseConv version 2 or above. + int dilation_width_factor; + int dilation_height_factor; } TfLiteDepthwiseConvParams; typedef struct { @@ -92,6 +99,12 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; +} TfLiteBidirectionalSequenceRNNParams; + typedef enum { kTfLiteFullyConnectedWeightsFormatDefault = 0, kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, @@ -173,6 +186,23 @@ typedef struct { TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; +typedef struct { + // Parameters for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; +} TfLiteUnidirectionalSequenceLSTMParams; + +typedef struct { + // Parameters for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. + bool merge_outputs; +} TfLiteBidirectionalSequenceLSTMParams; + typedef struct { bool align_corners; } TfLiteResizeBilinearParams; diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc index 4d0ba75e68367c9a0a7a7c9c3ac1ea14a875c201..ba458b4252c53ebc91adcd0afbd16f783037dd42 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -73,6 +73,8 @@ TEST(IntArray, CanCompileStructs) { TfLiteFakeQuantParams fake_quant_params; TfLitePackParams pack_params; TfLiteOneHotParams one_hot_params; + TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params; + TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params; } } // namespace tflite diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c index 1846bad4b742c26f0ed774af01f274b2abdc741a..8a0c177b1948df9b98e68f6cc6f44628ea8407a3 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.c +++ b/tensorflow/contrib/lite/c/c_api_internal.c @@ -14,15 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/c/c_api_internal.h" +#ifndef TF_LITE_STATIC_MEMORY #include #include #include +#endif // TF_LITE_STATIC_MEMORY int TfLiteIntArrayGetSizeInBytes(int size) { static TfLiteIntArray dummy; return sizeof(dummy) + sizeof(dummy.data[0]) * size; } +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +#ifndef TF_LITE_STATIC_MEMORY + TfLiteIntArray* TfLiteIntArrayCreate(int size) { TfLiteIntArray* ret = (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size)); @@ -40,16 +54,6 @@ void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { printf("]\n"); } -int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { - if (a == b) return 1; - if (a == NULL || b == NULL) return 0; - if (a->size != b->size) return 0; - int i = 0; - for (; i < a->size; i++) - if (a->data[i] != b->data[i]) return 0; - return 1; -} - TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { if (!src) return NULL; TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); @@ -102,3 +106,4 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { } tensor->bytes = num_bytes; } +#endif // TF_LITE_STATIC_MEMORY diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h index 34c874d1d27605ad3f241890ca7cfa93a855be64..ee3dff6792a33a575e75fe7a1ef3dc7985be9c1d 100644 --- a/tensorflow/contrib/lite/c/c_api_internal.h +++ b/tensorflow/contrib/lite/c/c_api_internal.h @@ -146,7 +146,7 @@ void TfLiteIntArrayFree(TfLiteIntArray* v); #define TF_LITE_ENSURE_OK(context, status) \ do { \ if ((status) != kTfLiteOk) { \ - return status; \ + return kTfLiteError; \ } \ } while (0) diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index eef4b6d831c493cf542d5f187536d7d9e5446f7c..890d9c04bb372fce2b86d503ec1f346a302786f2 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -44,16 +44,6 @@ void FlatBufferIntVectorToArray(int max_size_of_buffer, } } -// Allocate a structure using malloc, but make sure the structure is a POD -// structure that doesn't require constructors to run. The reason we do this, -// is that Interpreter's C extension part will take ownership so destructors -// will not be run during deallocation. -template -T* MallocPOD() { - static_assert(std::is_pod::value, "Builtin data structure must be POD."); - return static_cast(malloc(sizeof(T))); -} - } // namespace TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, @@ -98,7 +88,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, // need to be released by calling `free`.` // If it returns kTfLiteError, `builtin_data` will be `nullptr`. TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, - ErrorReporter* error_reporter, void** builtin_data) { + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data) { auto parse_padding = [](Padding padding) { switch (padding) { case Padding_SAME: @@ -150,7 +141,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { case BuiltinOperator_CONV_2D: { - TfLiteConvParams* params = MallocPOD(); + TfLiteConvParams* params = allocator->AllocatePOD(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { params->padding = parse_padding(conv_params->padding()); params->stride_width = conv_params->stride_w(); @@ -165,7 +156,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_CAST: { - TfLiteCastParams* params = MallocPOD(); + TfLiteCastParams* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { auto in_status = ConvertTensorType(schema_params->in_data_type(), @@ -174,7 +165,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, ConvertTensorType(schema_params->out_data_type(), ¶ms->out_data_type, error_reporter); if (in_status != kTfLiteOk || out_status != kTfLiteOk) { - free(params); + allocator->Deallocate(params); return kTfLiteError; } } @@ -183,7 +174,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { params->type = parseLSHProjectionType(lshParams->type()); } @@ -193,7 +184,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_AVERAGE_POOL_2D: case BuiltinOperator_MAX_POOL_2D: case BuiltinOperator_L2_POOL_2D: { - TfLitePoolParams* params = MallocPOD(); + TfLitePoolParams* params = allocator->AllocatePOD(); if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { params->padding = parse_padding(pool_params->padding()); params->stride_width = pool_params->stride_w(); @@ -208,7 +199,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_DEPTHWISE_CONV_2D: { TfLiteDepthwiseConvParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { params->padding = parse_padding(conv_params->padding()); params->stride_width = conv_params->stride_w(); @@ -216,12 +207,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->depth_multiplier = conv_params->depth_multiplier(); params->activation = parse_activation(conv_params->fused_activation_function()); + + params->dilation_width_factor = conv_params->dilation_w_factor(); + params->dilation_height_factor = conv_params->dilation_h_factor(); } *builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_SVDF: { - TfLiteSVDFParams* params = MallocPOD(); + TfLiteSVDFParams* params = allocator->AllocatePOD(); if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { params->rank = svdf_params->rank(); params->activation = @@ -230,9 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = MallocPOD(); + auto params = allocator->AllocatePOD(); if (auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = @@ -242,8 +235,21 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + allocator->AllocatePOD(); + if (auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = parse_activation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RNN: { - TfLiteRNNParams* params = MallocPOD(); + TfLiteRNNParams* params = allocator->AllocatePOD(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { params->activation = parse_activation(rnn_params->fused_activation_function()); @@ -253,7 +259,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* embedding_params = op->builtin_options_as_EmbeddingLookupSparseOptions()) { params->combiner = parseCombinerType(embedding_params->combiner()); @@ -263,7 +269,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_FULLY_CONNECTED: { TfLiteFullyConnectedParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* fully_connected_params = op->builtin_options_as_FullyConnectedOptions()) { params->activation = parse_activation( @@ -288,7 +294,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, // no-op. break; case BuiltinOperator_SOFTMAX: { - TfLiteSoftmaxParams* params = MallocPOD(); + TfLiteSoftmaxParams* params = + allocator->AllocatePOD(); if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { params->beta = softmax_params->beta(); } @@ -297,7 +304,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_CONCATENATION: { TfLiteConcatenationParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* concatenation_params = op->builtin_options_as_ConcatenationOptions()) { params->activation = @@ -308,7 +315,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_MUL: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_MulOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); @@ -317,7 +324,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_ADD: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_AddOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); @@ -326,7 +333,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_DIV: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_DivOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); @@ -335,7 +342,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SUB: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_SubOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); @@ -344,7 +351,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_L2_NORMALIZATION: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { params->activation = parse_activation(schema_params->fused_activation_function()); @@ -353,7 +360,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_LocalResponseNormalizationOptions()) { params->radius = schema_params->radius(); @@ -364,10 +371,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = MallocPOD(); + auto params = allocator->AllocatePOD(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = parse_activation(lstm_params->fused_activation_function()); @@ -385,8 +390,36 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto* params = + allocator->AllocatePOD(); + if (auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + } + *builtin_data = reinterpret_cast(params); + break; + } + + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + allocator->AllocatePOD(); + if (auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RESIZE_BILINEAR: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ResizeBilinearOptions()) { params->align_corners = schema_params->align_corners(); @@ -395,7 +428,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_RESHAPE: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, @@ -406,7 +439,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SKIP_GRAM: { - TfLiteSkipGramParams* params = MallocPOD(); + TfLiteSkipGramParams* params = + allocator->AllocatePOD(); if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { params->ngram_size = skip_gram_params->ngram_size(); params->max_skip_size = skip_gram_params->max_skip_size(); @@ -416,7 +450,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SPACE_TO_DEPTH: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { params->block_size = schema_params->block_size(); } @@ -424,7 +458,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_GATHER: { - TfLiteGatherParams* params = MallocPOD(); + TfLiteGatherParams* params = allocator->AllocatePOD(); params->axis = 0; if (auto* gather_params = op->builtin_options_as_GatherOptions()) { params->axis = gather_params->axis(); @@ -439,7 +473,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_REDUCE_PROD: case BuiltinOperator_REDUCE_ANY: case BuiltinOperator_SUM: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } @@ -447,7 +481,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SPLIT: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_SplitOptions()) { params->num_splits = schema_params->num_splits(); } @@ -455,7 +489,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SQUEEZE: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { const auto& squeeze_dims = schema_params->squeeze_dims(); FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, @@ -466,7 +500,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_STRIDED_SLICE: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { params->begin_mask = schema_params->begin_mask(); params->end_mask = schema_params->end_mask(); @@ -478,7 +512,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_ARG_MAX: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { ConvertTensorType(schema_params->output_type(), ¶ms->output_type, error_reporter); @@ -487,7 +521,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_ARG_MIN: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { ConvertTensorType(schema_params->output_type(), ¶ms->output_type, error_reporter); @@ -497,7 +531,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* transpose_conv_params = op->builtin_options_as_TransposeConvOptions()) { params->padding = parse_padding(transpose_conv_params->padding()); @@ -509,7 +543,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, } case BuiltinOperator_SPARSE_TO_DENSE: { TfLiteSparseToDenseParams* params = - MallocPOD(); + allocator->AllocatePOD(); if (auto* sparse_to_dense_params = op->builtin_options_as_SparseToDenseOptions()) { params->validate_indices = sparse_to_dense_params->validate_indices(); @@ -518,7 +552,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_SHAPE: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { ConvertTensorType(schema_params->out_type(), ¶ms->out_type, error_reporter); @@ -527,7 +561,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_PACK: { - TfLitePackParams* params = MallocPOD(); + TfLitePackParams* params = allocator->AllocatePOD(); if (auto* pack_params = op->builtin_options_as_PackOptions()) { params->values_count = pack_params->values_count(); params->axis = pack_params->axis(); @@ -541,7 +575,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, return kTfLiteError; } case BuiltinOperator_FAKE_QUANT: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { params->min = schema_params->min(); params->max = schema_params->max(); @@ -552,7 +586,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_ONE_HOT: { - auto* params = MallocPOD(); + auto* params = allocator->AllocatePOD(); if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { params->axis = schema_params->axis(); } @@ -560,7 +594,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_UNPACK: { - TfLiteUnpackParams* params = MallocPOD(); + TfLiteUnpackParams* params = allocator->AllocatePOD(); if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { params->num = unpack_params->num(); params->axis = unpack_params->axis(); @@ -615,6 +649,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOGICAL_NOT: case BuiltinOperator_FLOOR_DIV: case BuiltinOperator_SQUARE: + case BuiltinOperator_ZEROS_LIKE: + case BuiltinOperator_FILL: + case BuiltinOperator_FLOOR_MOD: break; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h index 4dec6f9cfcf9de5b2a0487b858e3baaec0d46021..c770e627fd572dc252c6261bd3713d3105d225f1 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h @@ -26,6 +26,25 @@ limitations under the License. namespace tflite { +// Interface class for builtin data allocations. +class BuiltinDataAllocator { + public: + virtual void* Allocate(size_t size) = 0; + virtual void Deallocate(void* data) = 0; + + // Allocate a structure, but make sure it is a POD structure that doesn't + // require constructors to run. The reason we do this, is that Interpreter's C + // extension part will take ownership so destructors will not be run during + // deallocation. + template + T* AllocatePOD() { + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + return static_cast(this->Allocate(sizeof(T))); + } + + virtual ~BuiltinDataAllocator() {} +}; + // Parse the appropriate data out of the op. // // This handles builtin data explicitly as there are flatbuffer schemas. @@ -36,7 +55,8 @@ namespace tflite { // function's responsibility to free it. // If it returns kTfLiteError, `builtin_data` will be `nullptr`. TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, - ErrorReporter* error_reporter, void** builtin_data); + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, void** builtin_data); // Converts the tensor data type used in the flat buffer to the representation // used by the runtime. diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc index b12bdf43b20c5af3fb28e2afcebe23b957af6333..8ae94e1d330c1958b857cff0b44c38108f153550 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc @@ -39,11 +39,31 @@ class MockErrorReporter : public ErrorReporter { int buffer_size_; }; +// Used to determine how the op data parsing function creates its working space. +class MockDataAllocator : public BuiltinDataAllocator { + public: + MockDataAllocator() : is_allocated_(false) {} + void* Allocate(size_t size) override { + EXPECT_FALSE(is_allocated_); + const int max_size = kBufferSize; + EXPECT_LE(size, max_size); + is_allocated_ = true; + return buffer_; + } + void Deallocate(void* data) override { is_allocated_ = false; } + + private: + static constexpr int kBufferSize = 1024; + char buffer_[kBufferSize]; + bool is_allocated_; +}; + } // namespace TEST(FlatbufferConversions, TestParseOpDataConv) { MockErrorReporter mock_reporter; ErrorReporter* reporter = &mock_reporter; + MockDataAllocator mock_allocator; flatbuffers::FlatBufferBuilder builder; flatbuffers::Offset conv_options = @@ -58,7 +78,7 @@ TEST(FlatbufferConversions, TestParseOpDataConv) { const Operator* conv_op = flatbuffers::GetRoot(conv_pointer); void* output_data = nullptr; EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter, - &output_data)); + &mock_allocator, &output_data)); EXPECT_NE(nullptr, output_data); TfLiteConvParams* params = reinterpret_cast(output_data); EXPECT_EQ(kTfLitePaddingSame, params->padding); @@ -67,12 +87,12 @@ TEST(FlatbufferConversions, TestParseOpDataConv) { EXPECT_EQ(kTfLiteActRelu, params->activation); EXPECT_EQ(3, params->dilation_width_factor); EXPECT_EQ(4, params->dilation_height_factor); - free(output_data); } TEST(FlatbufferConversions, TestParseOpDataCustom) { MockErrorReporter mock_reporter; ErrorReporter* reporter = &mock_reporter; + MockDataAllocator mock_allocator; flatbuffers::FlatBufferBuilder builder; flatbuffers::Offset null_options; @@ -84,7 +104,7 @@ TEST(FlatbufferConversions, TestParseOpDataCustom) { const Operator* custom_op = flatbuffers::GetRoot(custom_pointer); void* output_data = nullptr; EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter, - &output_data)); + &mock_allocator, &output_data)); EXPECT_EQ(nullptr, output_data); } diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD similarity index 94% rename from tensorflow/contrib/lite/delegates/eager/BUILD rename to tensorflow/contrib/lite/delegates/flex/BUILD index bf5d91899ca63142f69401229b9e06b27b6c2b0b..9b89ed4f849e224d36adae7c3a7581ac542d4f0f 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/flex/BUILD @@ -2,7 +2,7 @@ # This is a TF Lite delegate that is powered by TensorFlow's Eager. # package(default_visibility = [ - "//visibility:public", + "//visibility:private", ]) licenses(["notice"]) # Apache 2.0 @@ -20,7 +20,7 @@ cc_library( "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -50,6 +50,7 @@ cc_library( hdrs = [ "delegate.h", ], + visibility = ["//visibility:public"], deps = [ ":buffer_map", ":delegate_data", @@ -60,12 +61,13 @@ cc_library( "//tensorflow/contrib/lite:util", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:lib", ], }), + alwayslink = 1, ) tf_cc_test( @@ -178,7 +180,7 @@ cc_library( "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + "//tensorflow/core:android_tensorflow_lib", ], "//conditions:default": [ "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc similarity index 95% rename from tensorflow/contrib/lite/delegates/eager/buffer_map.cc rename to tensorflow/contrib/lite/delegates/flex/buffer_map.cc index e5a19c39976969a0b05b28596c6d7d5ebe7c7782..63e39196d96a176eca105e7b11107ab52fe528dd 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" namespace tflite { -namespace eager { +namespace flex { namespace { // A tensor buffer that is allocated, deallocated and populated by TF Lite. class TfLiteTensorBuffer : public tensorflow::TensorBuffer { @@ -107,5 +107,5 @@ void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { id_to_tensor_[tensor_index] = std::move(tensor); } -} // namespace eager +} // namespace flex } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/flex/buffer_map.h similarity index 86% rename from tensorflow/contrib/lite/delegates/eager/buffer_map.h rename to tensorflow/contrib/lite/delegates/flex/buffer_map.h index aaaa045840c4511f0018ad40414f8b36cea8994a..4ce886568a55773971bc0543ec973ec84c0aac1b 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ #include @@ -21,12 +21,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tflite { -namespace eager { +namespace flex { // Maps a TF Lite tensor index into a TensorFlow tensor. // // The TF Lite interpreter assigns integer indices to each of its tensors, but -// the Eager delegate deals in terms of TensorFlow tensors. This class maps +// the Flex delegate deals in terms of TensorFlow tensors. This class maps // from indices to tensors and allows the creation of new tensors to be // associated with a given index. class BufferMap { @@ -55,7 +55,7 @@ class BufferMap { std::map id_to_tensor_; }; -} // namespace eager +} // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc similarity index 98% rename from tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc rename to tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc index a046943e56d2b80f2670b7fc3dd57b36dc4d2425..bb80e25e8076bb95782e4137945ad1c7cd178aee 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" #include #include @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/contrib/lite/util.h" namespace tflite { -namespace eager { +namespace flex { namespace { using ::testing::ElementsAre; @@ -164,7 +164,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) { } } // namespace -} // namespace eager +} // namespace flex } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc similarity index 70% rename from tensorflow/contrib/lite/delegates/eager/delegate.cc rename to tensorflow/contrib/lite/delegates/flex/delegate.cc index 45fc158157b624ae99bd99ecfd136efcc69ca550..c72b0cf51383897ce3afec0c39ed6bfe178d88c1 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc @@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/delegate.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate.h" #include #include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" -#include "tensorflow/contrib/lite/delegates/eager/kernel.h" -#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" #include "tensorflow/contrib/lite/util.h" #include "tensorflow/core/lib/core/status.h" namespace tflite { -namespace eager { +namespace flex { namespace delegate { TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { @@ -32,7 +32,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); - // Add all custom ops starting with "Eager" to list of supported nodes. + // Add all custom ops starting with "Flex" to list of supported nodes. std::vector supported_nodes; for (int node_index : TfLiteIntArrayView(plan)) { TfLiteNode* node; @@ -40,7 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); - if (IsEagerOp(registration->custom_name)) { + if (IsFlexOp(registration->custom_name)) { supported_nodes.push_back(node_index); } } @@ -81,28 +81,37 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, } } // namespace delegate -} // namespace eager +} // namespace flex + +// Corresponding weak declaration found in lite/model.cc. +std::unique_ptr +AcquireFlexDelegate() { + return std::unique_ptr( + tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast(delegate); + }); +} -std::unique_ptr EagerDelegate::Create() { - std::unique_ptr delegate_data; - if (!eager::DelegateData::Create(&delegate_data).ok()) { +std::unique_ptr FlexDelegate::Create() { + std::unique_ptr delegate_data; + if (!flex::DelegateData::Create(&delegate_data).ok()) { fprintf(stderr, "Unable to initialize TensorFlow context.\n"); return nullptr; } - return std::unique_ptr( - new EagerDelegate(std::move(delegate_data))); + return std::unique_ptr( + new FlexDelegate(std::move(delegate_data))); } -EagerDelegate::EagerDelegate(std::unique_ptr delegate_data) +FlexDelegate::FlexDelegate(std::unique_ptr delegate_data) : TfLiteDelegate{ /*data_=*/delegate_data.get(), - /*nullptr,*/ &eager::delegate::Prepare, - /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*nullptr,*/ &flex::delegate::Prepare, + /*CopyFromBufferHandle=*/&flex::delegate::CopyFromBufferHandle, /*CopyToBufferHandle=*/nullptr, /*FreeBufferHandle=*/nullptr}, delegate_data_(std::move(delegate_data)) {} -EagerDelegate::~EagerDelegate() {} +FlexDelegate::~FlexDelegate() {} } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/flex/delegate.h similarity index 64% rename from tensorflow/contrib/lite/delegates/eager/delegate.h rename to tensorflow/contrib/lite/delegates/flex/delegate.h index 70f3c15af4af99380113e7404d7d487c2d9237cd..1017780dc75de1cd334e0cca901bbe20ddf0bf41 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/flex/delegate.h @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ #include "tensorflow/contrib/lite/c/c_api_internal.h" -#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" namespace tflite { @@ -24,12 +24,12 @@ namespace tflite { // Delegate that can be used to extract parts of a graph that are designed to be // executed by TensorFlow's runtime via Eager. // -// The interpreter must be constructed after the EagerDelegate and destructed -// before the EagerDelegate. This delegate may be used with multiple +// The interpreter must be constructed after the FlexDelegate and destructed +// before the FlexDelegate. This delegate may be used with multiple // interpreters, but it is *not* thread-safe. // // Usage: -// auto delegate = EagerDelegate::Create(); +// auto delegate = FlexDelegate::Create(); // ... build interpreter ... // // if (delegate) { @@ -39,21 +39,21 @@ namespace tflite { // ... run inference ... // ... destroy interpreter ... // ... destroy delegate ... -class EagerDelegate : public TfLiteDelegate { +class FlexDelegate : public TfLiteDelegate { public: // Creates a delegate that supports TF ops. // - // If the underyling TF Eager context creation fails, returns null. - static std::unique_ptr Create(); + // If the underyling TF Flex context creation fails, returns null. + static std::unique_ptr Create(); - ~EagerDelegate(); + ~FlexDelegate(); private: - explicit EagerDelegate(std::unique_ptr delegate_data); + explicit FlexDelegate(std::unique_ptr delegate_data); - std::unique_ptr delegate_data_; + std::unique_ptr delegate_data_; }; } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc similarity index 94% rename from tensorflow/contrib/lite/delegates/eager/delegate_data.cc rename to tensorflow/contrib/lite/delegates/flex/delegate_data.cc index 0fd5c976f8ca9be16f7e3c5e610573755b40c506..8f985f770cfba9fc6a7184cfdb0a35e9e6c754af 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/core/status.h" namespace tflite { -namespace eager { +namespace flex { tensorflow::Status DelegateData::Create(std::unique_ptr* data) { std::vector devices; @@ -43,5 +43,5 @@ DelegateData::DelegateData(tensorflow::EagerContext* eager_context) DelegateData::~DelegateData() {} -} // namespace eager +} // namespace flex } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/flex/delegate_data.h similarity index 78% rename from tensorflow/contrib/lite/delegates/eager/delegate_data.h rename to tensorflow/contrib/lite/delegates/flex/delegate_data.h index 772d26f44e8b5b2b962c06f42b86df29ee1c1f8d..8d75f0b0efe758074d035f0ebcf0f5f12602323b 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data.h +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.h @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ -#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" #include "tensorflow/core/common_runtime/eager/context.h" namespace tflite { -namespace eager { +namespace flex { -// Data kept by the Eager delegate for the lifetime of an Interpreter. +// Data kept by the Flex delegate for the lifetime of an Interpreter. class DelegateData { public: // Create a new DelegateData, initialized with a newly-created EagerContext. @@ -29,7 +29,7 @@ class DelegateData { ~DelegateData(); - // The EagerContext that is required for execution of Eager Ops. + // The EagerContext that is required for execution of Flex Ops. tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } // Map from TF Lite tensor index to TensorFlow tensor for a given context. @@ -46,7 +46,7 @@ class DelegateData { std::unordered_map buffer_map_; }; -} // namespace eager +} // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc similarity index 93% rename from tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc rename to tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc index def063309fb8019b2835cfd8943dd1b7a7034a14..30b10f435a23785f88e2645714a414501bc2fab9 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" #include #include @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { -namespace eager { +namespace flex { namespace { TEST(DelegateDataTest, Basic) { @@ -39,7 +39,7 @@ TEST(DelegateDataTest, Basic) { } } // namespace -} // namespace eager +} // namespace flex } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc similarity index 95% rename from tensorflow/contrib/lite/delegates/eager/delegate_test.cc rename to tensorflow/contrib/lite/delegates/flex/delegate_test.cc index 43ec5d53b81e385b2ca3a9614ccd7b8fdff2557d..1813952cef99ef10b638ade7bcfcca486b2b3b76 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc @@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/delegate.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate.h" #include #include -#include "tensorflow/contrib/lite/delegates/eager/test_util.h" +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" namespace tflite { -namespace eager { +namespace flex { namespace { using ::testing::ContainsRegex; using ::testing::ElementsAre; -class DelegateTest : public testing::EagerModelTest { +class DelegateTest : public testing::FlexModelTest { public: DelegateTest() { - delegate_ = EagerDelegate::Create(); + delegate_ = FlexDelegate::Create(); interpreter_.reset(new Interpreter(&error_reporter_)); } @@ -46,7 +46,7 @@ class DelegateTest : public testing::EagerModelTest { } private: - std::unique_ptr delegate_; + std::unique_ptr delegate_; }; TEST_F(DelegateTest, FullGraph) { @@ -236,7 +236,7 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { } } // namespace -} // namespace eager +} // namespace flex } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/flex/kernel.cc similarity index 90% rename from tensorflow/contrib/lite/delegates/eager/kernel.cc rename to tensorflow/contrib/lite/delegates/flex/kernel.cc index 274c3c082a4a4d28c891082156877ce42afc2507..e4f1aea990da97da08a3e5adf2dd70307b20fe88 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/flex/kernel.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" -#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" -#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/string.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -28,10 +28,10 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" -// Note: this is part of TF Lite's Eager delegation code which is to be +// Note: this is part of TF Lite's Flex delegation code which is to be // completed soon. -// This is the TF Lite op that is created by the eager delegate to handle +// This is the TF Lite op that is created by the flex delegate to handle // execution of a supported subgraph. The usual flow is that the delegate // informs the interpreter of supported nodes in a graph, and each supported // subgraph is replaced with one instance of this kernel. @@ -46,7 +46,7 @@ limitations under the License. // corresponding TensorFlow/Eager Op. namespace tflite { -namespace eager { +namespace flex { namespace kernel { // Controls the lifetime of tensor handles in a vector. @@ -72,11 +72,11 @@ class VectorOfHandles { // Executes the TensorFlow op given by 'op_name', with the attributes specified // in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. -tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context, - BufferMap* buffer_map, const string& op_name, - const tensorflow::NodeDef& nodedef, - const std::vector& inputs, - const std::vector& outputs) { +tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context, + BufferMap* buffer_map, const string& op_name, + const tensorflow::NodeDef& nodedef, + const std::vector& inputs, + const std::vector& outputs) { const tensorflow::AttrTypeMap* attr_types; TF_RETURN_WITH_CONTEXT_IF_ERROR( tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types), @@ -258,13 +258,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Execute the TensorFlow Ops sequentially. for (const auto& node_data : op_data->nodes) { if (node_data.nodedef.op().empty()) { - context->ReportError(context, "Invalid NodeDef in Eager op '%s'", + context->ReportError(context, "Invalid NodeDef in Flex op '%s'", node_data.name.c_str()); return kTfLiteError; } auto status = - ExecuteEagerOp(eager_context, buffer_map, node_data.name, - node_data.nodedef, node_data.inputs, node_data.outputs); + ExecuteFlexOp(eager_context, buffer_map, node_data.name, + node_data.nodedef, node_data.inputs, node_data.outputs); TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); } @@ -295,5 +295,5 @@ TfLiteRegistration GetKernel() { return registration; } -} // namespace eager +} // namespace flex } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/flex/kernel.h similarity index 79% rename from tensorflow/contrib/lite/delegates/eager/kernel.h rename to tensorflow/contrib/lite/delegates/flex/kernel.h index 2478abccaa1d9466cbadad08e73ba58bb970b4a1..ac9313a37bd5a3f5e23057512f07674c44801989 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.h +++ b/tensorflow/contrib/lite/delegates/flex/kernel.h @@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ #include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { -namespace eager { +namespace flex { // Return the registration object used to initialize and execute ops that will // be delegated to TensorFlow's Eager runtime. This TF Lite op is created by -// the eager delegate to handle execution of a supported subgraph. The usual +// the flex delegate to handle execution of a supported subgraph. The usual // flow is that the delegate informs the interpreter of supported nodes in a // graph, and each supported subgraph is replaced with one instance of this // kernel. TfLiteRegistration GetKernel(); -} // namespace eager +} // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc similarity index 94% rename from tensorflow/contrib/lite/delegates/eager/kernel_test.cc rename to tensorflow/contrib/lite/delegates/flex/kernel_test.cc index 66f2226626677fa26a8c0eb2ae8ef448ed35c141..94a6f8b61ad28144f6b8d0d462338ab4176af168 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc +++ b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" #include #include -#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" -#include "tensorflow/contrib/lite/delegates/eager/test_util.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" namespace tflite { -namespace eager { +namespace flex { namespace { using ::testing::ContainsRegex; @@ -31,12 +31,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, TfLiteIntArray* size_and_nodes = ConvertVectorToTfLiteIntArray(supported_nodes); TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels( - context, eager::GetKernel(), size_and_nodes, delegate)); + context, flex::GetKernel(), size_and_nodes, delegate)); TfLiteIntArrayFree(size_and_nodes); return kTfLiteOk; } -class KernelTest : public testing::EagerModelTest { +class KernelTest : public testing::FlexModelTest { public: KernelTest() { CHECK(DelegateData::Create(&delegate_data_).ok()); @@ -167,7 +167,7 @@ TEST_F(KernelTest, WrongSetOfNodes) { ASSERT_FALSE(Invoke()); ASSERT_THAT(error_reporter().error_messages(), - ContainsRegex("Invalid NodeDef in Eager op")); + ContainsRegex("Invalid NodeDef in Flex op")); } TEST_F(KernelTest, MixedGraph) { @@ -220,7 +220,7 @@ TEST_F(KernelTest, SplitGraph) { } } // namespace -} // namespace eager +} // namespace flex } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/flex/test_util.cc similarity index 75% rename from tensorflow/contrib/lite/delegates/eager/test_util.cc rename to tensorflow/contrib/lite/delegates/flex/test_util.cc index 8584999ace893b0963debc8dad9ead4a2597fa06..69c336a01a57416bb331a897faba03ad75a38f95 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.cc +++ b/tensorflow/contrib/lite/delegates/flex/test_util.cc @@ -13,25 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/test_util.h" +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" #include "absl/memory/memory.h" -#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "tensorflow/contrib/lite/string.h" namespace tflite { -namespace eager { +namespace flex { namespace testing { -bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } +bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } -void EagerModelTest::SetShape(int tensor_index, - const std::vector& values) { +void FlexModelTest::SetShape(int tensor_index, const std::vector& values) { ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); } -std::vector EagerModelTest::GetShape(int tensor_index) { +std::vector FlexModelTest::GetShape(int tensor_index) { std::vector result; auto* dims = interpreter_->tensor(tensor_index)->dims; result.reserve(dims->size); @@ -41,13 +40,13 @@ std::vector EagerModelTest::GetShape(int tensor_index) { return result; } -TfLiteType EagerModelTest::GetType(int tensor_index) { +TfLiteType FlexModelTest::GetType(int tensor_index) { return interpreter_->tensor(tensor_index)->type; } -void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, - const std::vector& outputs, - TfLiteType type, const std::vector& dims) { +void FlexModelTest::AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, TfLiteType type, + const std::vector& dims) { interpreter_->AddTensors(num_tensors); for (int i = 0; i < num_tensors; ++i) { TfLiteQuantizationParams quant; @@ -66,8 +65,8 @@ void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); } -void EagerModelTest::AddTfLiteMulOp(const std::vector& inputs, - const std::vector& outputs) { +void FlexModelTest::AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs) { static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; reg.builtin_code = BuiltinOperator_MUL; reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { @@ -90,8 +89,8 @@ void EagerModelTest::AddTfLiteMulOp(const std::vector& inputs, kTfLiteOk); } -void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, - const std::vector& outputs) { +void FlexModelTest::AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs) { auto attr = [](const string& key, const string& value) { return " attr{ key: '" + key + "' value {" + value + "}}"; }; @@ -107,28 +106,28 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, if (op == kUnpack) { string attributes = type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); - AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); + AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { string attributes = type_attribute; - AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); + AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { string attributes = type_attribute; - AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); + AddTfOp("FlexAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { string attributes = type_attribute; - AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); + AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); } else if (op == kIncompatibleNodeDef) { // "Cast" op is created without attributes - making it incompatible. - AddTfOp("EagerCast", "Cast", "", inputs, outputs); + AddTfOp("FlexCast", "Cast", "", inputs, outputs); } } -void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name, - const string& nodedef_str, - const std::vector& inputs, - const std::vector& outputs) { +void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, + const std::vector& inputs, + const std::vector& outputs) { static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; reg.builtin_code = BuiltinOperator_CUSTOM; reg.custom_name = tflite_name; @@ -154,5 +153,5 @@ void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name, } } // namespace testing -} // namespace eager +} // namespace flex } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/flex/test_util.h similarity index 90% rename from tensorflow/contrib/lite/delegates/eager/test_util.h rename to tensorflow/contrib/lite/delegates/flex/test_util.h index 816db41931610c913cf406095c308a5f7682f7cb..a8c81b90a3b8dc49ae058adb172456fe4d6e7172 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.h +++ b/tensorflow/contrib/lite/delegates/flex/test_util.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ #include "tensorflow/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/test_util.h" namespace tflite { -namespace eager { +namespace flex { namespace testing { enum TfOpType { @@ -35,12 +35,12 @@ enum TfOpType { }; // This class creates models with TF and TFLite ops. In order to use this class -// to test the Eager delegate, implement a function that calls +// to test the Flex delegate, implement a function that calls // interpreter->ModifyGraphWithDelegate. -class EagerModelTest : public ::testing::Test { +class FlexModelTest : public ::testing::Test { public: - EagerModelTest() {} - ~EagerModelTest() {} + FlexModelTest() {} + ~FlexModelTest() {} bool Invoke(); @@ -104,7 +104,7 @@ class EagerModelTest : public ::testing::Test { private: // Helper method to add a TensorFlow op. tflite_names needs to start with - // "Eager" in order to work with the Eager delegate. + // "Flex" in order to work with the Flex delegate. void AddTfOp(const char* tflite_name, const string& tf_name, const string& nodedef_str, const std::vector& inputs, const std::vector& outputs); @@ -113,7 +113,7 @@ class EagerModelTest : public ::testing::Test { }; } // namespace testing -} // namespace eager +} // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/flex/util.cc similarity index 96% rename from tensorflow/contrib/lite/delegates/eager/util.cc rename to tensorflow/contrib/lite/delegates/flex/util.cc index 051246bf866c1f1d479102e8c3dd335ab1b1ca1e..829bc388bf4f613e82600edfc7363d0774d49878 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/flex/util.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" namespace tflite { -namespace eager { +namespace flex { TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status) { @@ -100,5 +100,5 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { } } -} // namespace eager +} // namespace flex } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/flex/util.h similarity index 89% rename from tensorflow/contrib/lite/delegates/eager/util.h rename to tensorflow/contrib/lite/delegates/flex/util.h index 930cb99cb953982fc88ab82710a6ebeb0271a970..7f910e7316e67363a6e54389f1d0cc94b3e009a0 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/flex/util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ -#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ #include "tensorflow/c/c_api_internal.h" #include "tensorflow/contrib/lite/c/c_api_internal.h" @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tflite { -namespace eager { +namespace flex { // Converts a tensorflow:Status into a TfLiteStatus. If the original status // represented an error, reports it using the given 'context'. @@ -41,7 +41,7 @@ TF_DataType GetTensorFlowDataType(TfLiteType type); // Returns the TfLiteType that corresponds to the given TF C API Data type. TfLiteType GetTensorFlowLiteType(TF_DataType); -} // namespace eager +} // namespace flex } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/flex/util_test.cc similarity index 97% rename from tensorflow/contrib/lite/delegates/eager/util_test.cc rename to tensorflow/contrib/lite/delegates/flex/util_test.cc index aebc91149ce0c4677f4cf48f87a9a2a92279fd8d..5f049e7b0a0c1f7be28d33b532157c6f9211c7c1 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/flex/util_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" #include @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { -namespace eager { +namespace flex { namespace { using tensorflow::DT_FLOAT; @@ -132,7 +132,7 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) { } } // namespace -} // namespace eager +} // namespace flex } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index c6587b3d3f6564d3818f4d7d7c1739ea02a62afa..d85e576284fac87519d7f4bb4bd76fe2619b59d5 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -518,7 +518,7 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinReshape: - if (version == 1) { + if (version == 1 && node->inputs->size == 2) { return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { return ANEURALNETWORKS_RESHAPE; diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 4d2437e7d3714e1b8b427b0c6197b295c0355b07..d180cb478566a9e5df24b2e67445f24a2f623215 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -28,6 +28,7 @@ android_binary( srcs = glob([ "app/src/main/java/**/*.java", ]), + aapt_version = "aapt", # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md index cbdeeac8790d93210a6c637953605b4ca270d3f6..7347147f997540e67c2c713b597dc90d933c5cb8 100644 --- a/tensorflow/contrib/lite/examples/android/app/README.md +++ b/tensorflow/contrib/lite/examples/android/app/README.md @@ -1,8 +1,43 @@ # TF Lite Android App Example +A simple Android example that demonstrates image classification and object +detection using the camera, as well as speech recognition using the microphone. + +## Building in Android Studio with TensorFlow Lite AAR from JCenter. +The build.gradle is configured to use TensorFlow Lite's nightly build. + +If you see a build error related to compatibility with Tensorflow Lite's Java +API (example: method X is undefined for type Interpreter), there has likely been +a backwards compatible change to the API. You will need to pull new app code +that's compatible with the nightly build and may need to first wait a few days +for our external and internal code to merge. + ## Building from Source with Bazel -1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source). +1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): + + 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites). + It's easiest with Android Studio. + + - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. + - Bazel requires Android Build Tools `26.0.1` or higher. + - You also need to install the Android Support Repository, available + through Android Studio under `Android SDK Manager -> SDK Tools -> + Android Support Repository`. + + 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) + to add SDK and NDK targets. + + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that + you have installed. + - By default, Android Studio will install the SDK to `~/Android/Sdk` and + the NDK to `~/Android/Sdk/ndk-bundle`. 2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device: diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 6fdcf78b69c6799fc2e666af1150efb88b55ff5c..21ad39a6bf75e536ed099cb6120407be880404f0 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -80,8 +80,7 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, interpreter->Invoke(); auto output = interpreter->typed_tensor(2); - auto output_number_of_pixels = - wanted_height * wanted_height * wanted_channels; + auto output_number_of_pixels = wanted_height * wanted_width * wanted_channels; for (int i = 0; i < output_number_of_pixels; i++) { if (s->input_floating) diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index ea4a5432526780b8909a6f2f7d857367655f4f36..52e71619def71a0c2130539afe8e7d00e7a24894 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -1,5 +1,12 @@ package(default_visibility = ["//visibility:private"]) +package_group( + name = "experimental", + packages = [ + "//tensorflow/contrib/lite/experimental/...", + ], +) + licenses(["notice"]) # Apache 2.0 load( @@ -51,6 +58,9 @@ cc_library( srcs = ["c_api.cc"], hdrs = ["c_api.h"], copts = tflite_copts(), + visibility = [ + ":experimental", + ], deps = [ ":c_api_internal", "//tensorflow/contrib/lite:context", @@ -68,6 +78,7 @@ cc_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/contrib/lite:kernel_api", ], ) @@ -93,6 +104,7 @@ cc_test( deps = [ ":c_api", ":c_api_experimental", + "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index c589cf71ea55aaf8d6b467c4ea17ca8a895e25a4..9c29f9d8b9ddfd311ee1f4cd20722880e87d3b46 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -26,6 +27,26 @@ limitations under the License. extern "C" { #endif // __cplusplus +namespace { +class CallbackErrorReporter : public tflite::ErrorReporter { + public: + using ErrorCallback = void (*)(void* user_data, const char* format, + va_list args); + + CallbackErrorReporter(ErrorCallback callback, void* user_data) + : callback_(callback), user_data_(user_data) {} + + int Report(const char* format, va_list args) override { + callback_(user_data_, format, args); + return 0; + } + + private: + ErrorCallback callback_; + void* user_data_; +}; +} // namespace + // LINT.IfChange TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) { @@ -56,14 +77,38 @@ void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options, options->num_threads = num_threads; } +TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter( + TFL_InterpreterOptions* options, + void (*reporter)(void* user_data, const char* format, va_list args), + void* user_data) { + options->error_reporter = reporter; + options->error_reporter_user_data = user_data; +} + TFL_Interpreter* TFL_NewInterpreter( const TFL_Model* model, const TFL_InterpreterOptions* optional_options) { if (!model || !model->impl) { return nullptr; } + std::unique_ptr optional_error_reporter; + if (optional_options && optional_options->error_reporter != nullptr) { + optional_error_reporter.reset( + new CallbackErrorReporter(optional_options->error_reporter, + optional_options->error_reporter_user_data)); + } + + // TODO(b/111881878): Allow use of C API without pulling in all builtin ops. tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::InterpreterBuilder builder(*model->impl, resolver); + if (optional_options) { + resolver.AddAll(optional_options->op_resolver); + } + tflite::ErrorReporter* error_reporter = optional_error_reporter + ? optional_error_reporter.get() + : tflite::DefaultErrorReporter(); + tflite::InterpreterBuilder builder(model->impl->GetModel(), resolver, + error_reporter); + std::unique_ptr interpreter; if (builder(&interpreter) != kTfLiteOk) { return nullptr; @@ -76,7 +121,8 @@ TFL_Interpreter* TFL_NewInterpreter( } } - return new TFL_Interpreter{model->impl, std::move(interpreter)}; + return new TFL_Interpreter{model->impl, std::move(optional_error_reporter), + std::move(interpreter)}; } void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; } diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h index b429e7687003fa15b096292658658613a9203c67..f52ab8f9ed65aa288a74e4e486ac060fa9dbebe0 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.h +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ #define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ +#include #include // Eventually the various C APIs defined in context.h will be migrated into @@ -52,8 +53,9 @@ limitations under the License. extern "C" { #endif // __cplusplus -typedef TfLiteTensor TFL_Tensor; +typedef TfLiteRegistration TFL_Registration; typedef TfLiteStatus TFL_Status; +typedef TfLiteTensor TFL_Tensor; typedef TfLiteType TFL_Type; // -------------------------------------------------------------------------- @@ -85,6 +87,17 @@ TFL_CAPI_EXPORT extern void TFL_DeleteInterpreterOptions( TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads( TFL_InterpreterOptions* options, int32_t num_threads); +// Sets a custom error reporter for interpreter execution. +// +// * `reporter` takes the provided `user_data` object, as well as a C-style +// format string and arg list (see also vprintf). +// * `user_data` is optional. If provided, it is owned by the client and must +// remain valid for the duration of the interpreter lifetime. +TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter( + TFL_InterpreterOptions* options, + void (*reporter)(void* user_data, const char* format, va_list args), + void* user_data); + // -------------------------------------------------------------------------- // TFL_Interpreter provides inference from a provided model. typedef struct TFL_Interpreter TFL_Interpreter; diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc index c4dbc55cbf6b116df46553411be5337f83ceb4e7..29f8701f53407dc47adfaca8c85c86210e4cb09a 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc @@ -21,9 +21,24 @@ limitations under the License. extern "C" { #endif // __cplusplus -TFL_Status TFL_InterpreterResetVariableTensorsToZero( - TFL_Interpreter* interpreter) { - return interpreter->impl->ResetVariableTensorsToZero(); +TFL_Status TFL_InterpreterResetVariableTensors(TFL_Interpreter* interpreter) { + return interpreter->impl->ResetVariableTensors(); +} + +void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options, + TFL_BuiltinOperator op, + const TFL_Registration* registration, + int32_t min_version, + int32_t max_version) { + options->op_resolver.AddBuiltin(static_cast(op), + registration, min_version, max_version); +} + +void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options, + const char* name, + const TFL_Registration* registration, + int min_version, int max_version) { + options->op_resolver.AddCustom(name, registration, min_version, max_version); } #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h index b0ac258dcf9bf4ab603ba847f1b111a89cf2f29b..fca5d92f77caff987f6a70c3a8fd03849bce1165 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h @@ -15,16 +15,41 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ #define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/experimental/c/c_api.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus +typedef TfLiteBuiltinOperator TFL_BuiltinOperator; + // Resets all variable tensors to zero. -TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero( +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors( TFL_Interpreter* interpreter); +// Adds an op registration for a builtin operator. +// +// NOTE: The interpreter will make a copy of `registration` internally, so the +// caller should ensure that its contents (function pointers, etc...) remain +// valid for the duration of the interpreter's lifetime. A common practice is +// making the provided TFL_Registration instance static. +void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options, + TFL_BuiltinOperator op, + const TFL_Registration* registration, + int min_version, int max_version); + +// Adds an op registration for a custom operator. +// +// NOTE: The interpreter will make a copy of `registration` internally, so the +// caller should ensure that its contents (function pointers, etc...) remain +// valid for the duration of the interpreter's lifetime. A common practice is +// making the provided TFL_Registration instance static. +void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options, + const char* name, + const TFL_Registration* registration, + int min_version, int max_version); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc index db6e5251de518d2e754f853edbfb1c1edc425a83..1b1bedb75470638d4b3cfac92819e18b8fe6e65a 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc @@ -16,25 +16,40 @@ limitations under the License. #include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" #include +#include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/experimental/c/c_api.h" #include "tensorflow/contrib/lite/testing/util.h" namespace { +TfLiteRegistration* GetDummyRegistration() { + static TfLiteRegistration registration = { + .init = nullptr, + .free = nullptr, + .prepare = nullptr, + .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }, + }; + return ®istration; +} + TEST(CApiExperimentalSimple, Smoke) { TFL_Model* model = TFL_NewModelFromFile( "tensorflow/contrib/lite/testdata/add.bin"); ASSERT_NE(model, nullptr); - TFL_Interpreter* interpreter = - TFL_NewInterpreter(model, /*optional_options=*/nullptr); + TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); + TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd, + GetDummyRegistration(), 1, 1); + + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); ASSERT_NE(interpreter, nullptr); ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TFL_InterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); - EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk); - - TFL_DeleteModel(model); TFL_DeleteInterpreter(interpreter); + TFL_DeleteInterpreterOptions(options); + TFL_DeleteModel(model); } } // namespace diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h index 60c2e4e2cd93848b8e709a06225ce548ed2a7686..da3af3cad4c54865cfe778b79538e5800c284985 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h @@ -19,9 +19,13 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" // Internal structures used by the C API. These are likely to change and should // not be depended on. +// +// NOTE: This header does not follow C conventions and does not define a C API. +// It is effectively an (internal) implementation detail of the C API. struct TFL_Model { // Sharing is safe as FlatBufferModel is const. @@ -33,12 +37,24 @@ struct TFL_InterpreterOptions { kDefaultNumThreads = -1, }; int num_threads = kDefaultNumThreads; + + tflite::MutableOpResolver op_resolver; + + void (*error_reporter)(void* user_data, const char* format, + va_list args) = nullptr; + void* error_reporter_user_data = nullptr; }; struct TFL_Interpreter { // Taking a reference to the (const) model data avoids lifetime-related issues // and complexity with the TFL_Model's existence. std::shared_ptr model; + + // The interpreter does not take ownership of the provided ErrorReporter + // instance, so we ensure its validity here. Note that the interpreter may use + // the reporter in its destructor, so it should be declared first. + std::unique_ptr optional_error_reporter; + std::unique_ptr impl; }; diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc index 649dac8d1ad5475c2a174458a00c1a93c9585e2f..48a3714ec345a6f4bc4be8ebe937471a91c60218 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc @@ -85,6 +85,37 @@ TEST(CApiSimple, Smoke) { TFL_DeleteInterpreter(interpreter); } +TEST(CApiSimple, ErrorReporter) { + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/contrib/lite/testdata/add.bin"); + TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); + + // Install a custom error reporter into the interpreter by way of options. + tflite::TestErrorReporter reporter; + TFL_InterpreterOptionsSetErrorReporter( + options, + [](void* user_data, const char* format, va_list args) { + reinterpret_cast(user_data)->Report(format, + args); + }, + &reporter); + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); + + // The options/model can be deleted immediately after interpreter creation. + TFL_DeleteInterpreterOptions(options); + TFL_DeleteModel(model); + + // Invoke the interpreter before tensor allocation. + EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteError); + + // The error should propagate to the custom error reporter. + EXPECT_EQ(reporter.error_messages(), + "Invoke called on model that is not ready."); + EXPECT_EQ(reporter.num_calls(), 1); + + TFL_DeleteInterpreter(interpreter); +} + } // namespace int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2125f218ca877f94ec9f4d98928b6a1c8f2576eb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "tflite_lstm", + srcs = ["tflite_lstm.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/python:framework", + "@six_archive//:six", + ], +) + +py_test( + name = "unidirectional_sequence_lstm_test", + size = "large", + srcs = ["unidirectional_sequence_lstm_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":tflite_lstm", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python/tools:optimize_for_inference", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..2357743266f7082a5a003153718de08c83174ea5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py @@ -0,0 +1,396 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TfLite LSTMCell wrapper. + +TODO(renjieliu): Find a better home for this one. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf + +from tensorflow.contrib.lite.python import lite +from tensorflow.python.keras import activations +from tensorflow.python.keras import initializers +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.platform import tf_logging as logging + + +class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + This is used only for TfLite, it provides hints and it also makes the + variables in the desired for the tflite ops (transposed and seaparated). + + The default non-peephole implementation is based on: + + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf + + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. + + The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or + `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for + better performance on CPU. + """ + + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + num_unit_shards=None, + num_proj_shards=None, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None, + dtype=None): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 in + order to reduce the scale of forgetting at the beginning of the + training. Must set it manually to `0.0` when restoring from CudnnLSTM + trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of the + `c_state` and `m_state`. If False, they are concatenated along the + column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. When + restoring from CudnnLSTM-trained checkpoints, use + `CudnnCompatibleLSTMCell` instead. + """ + super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + # TODO(raziel): decide if we want to just support tuples (yes please!). + if not state_is_tuple: + logging.warn( + "%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + + # Inputs must be 2-dimensional. + # TODO(raziel): layers stuff -- chop if un-layerizing Op. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceLstm") + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._proj_clip = proj_clip + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation or math_ops.tanh + + self._output_size = num_proj if num_proj else num_units + self._state_size = ( + tf.nn.rnn_cell.LSTMStateTuple(num_units, self._output_size) + if state_is_tuple else num_units + self._output_size) + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def build(self, inputs_shape): + """Build TfLite LSTM cell graph. + + Args: + inputs_shape: The inputs_shape must be known, and is [batch_size, + input_size] shape. + + Raises: + ValueError: if the inputs_shape is invalid. + """ + if len(inputs_shape) != 2 or inputs_shape[1].value is None: + raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + maybe_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_unit_shards) + if self._num_unit_shards is not None else None) + input_weight_shape = [self._num_units, input_depth] + cell_weight_shape = [self._num_units, self._output_size] + bias_shape = [self._num_units] + + def add_variable_wrapped(name, shape, initializer, index, partitioner): + var = self.add_variable( + name, shape=shape, initializer=initializer, partitioner=partitioner) + return self._tflite_wrapper.add_input( + var, name="name", index_override=index) + + weight_initializer = self._initializer + if self.dtype is None: + bias_initializer = init_ops.zeros_initializer + else: + bias_initializer = init_ops.zeros_initializer(dtype=self.dtype) + + self.input_to_input_w = add_variable_wrapped( + "input_to_input_w", input_weight_shape, weight_initializer, 1, + maybe_partitioner) + self.input_to_forget_w = add_variable_wrapped( + "input_to_forget_w", input_weight_shape, weight_initializer, 2, + maybe_partitioner) + self.input_to_cell_w = add_variable_wrapped( + "input_to_cell_w", input_weight_shape, weight_initializer, 3, + maybe_partitioner) + self.input_to_output_w = add_variable_wrapped( + "input_to_output_w", input_weight_shape, weight_initializer, 4, + maybe_partitioner) + self.cell_to_input_w = add_variable_wrapped( + "cell_to_input_w", cell_weight_shape, weight_initializer, 5, + maybe_partitioner) + self.cell_to_forget_w = add_variable_wrapped( + "cell_to_forget_w", cell_weight_shape, weight_initializer, 6, + maybe_partitioner) + self.cell_to_cell_w = add_variable_wrapped( + "cell_to_cell_w", cell_weight_shape, weight_initializer, 7, + maybe_partitioner) + self.cell_to_output_w = add_variable_wrapped( + "cell_to_output_w", cell_weight_shape, weight_initializer, 8, + maybe_partitioner) + + self.input_bias = add_variable_wrapped( + "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner) + self.forget_bias = add_variable_wrapped( + "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner) + self.cell_bias = add_variable_wrapped( + "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner) + self.output_bias = add_variable_wrapped( + "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner) + + # index 9, 10, 11. + # f stands for forget, i stands for input and o stands for output. + if self._use_peepholes: + self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units], + self._initializer, 9, + maybe_partitioner) + self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units], + self._initializer, 10, + maybe_partitioner) + self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units], + self._initializer, 11, + maybe_partitioner) + + # index 16 for proj kernel. + if self._num_proj is not None: + maybe_proj_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_proj_shards) + if self._num_proj_shards is not None else None) + self._proj_kernel = add_variable_wrapped( + "projection/kernel", [self._num_proj, self._num_units], + self._initializer, + 16, + partitioner=maybe_proj_partitioner) + + self.built = True + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, `[batch, num_units]`. + state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, + [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple + of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch, output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + inputs = self._tflite_wrapper.add_input( + inputs, tag="input", name="input", aggregate="stack", index_override=0) + + # Make sure inputs and bias_initializer has the same type. + assert inputs.dtype == self.input_to_input_w.dtype + + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + # Note: For TfLite, cell_state is at index 19 while activation state at + # index 18. + c_prev = self._tflite_wrapper.add_input( + c_prev, + tag="c_prev", + name="c_prev", + aggregate="first", + index_override=19) + m_prev = self._tflite_wrapper.add_input( + m_prev, + tag="m_prev", + name="m_prev", + aggregate="first", + index_override=18) + + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1) + + # i stands for input gate. + # f stands for forget gate activation. + # o outputs. + # j output of LSTM unit. + # c is the final state. + # m is the output. + i = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_input_w, self.cell_to_input_w], axis=1), + transpose_b=True), self.input_bias) + f = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_forget_w, self.cell_to_forget_w], axis=1), + transpose_b=True), self.forget_bias) + o = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_output_w, self.cell_to_output_w], axis=1), + transpose_b=True), self.output_bias) + j = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_cell_w, self.cell_to_cell_w], axis=1), + transpose_b=True), self.cell_bias) + + # Diagonal connections + if self._use_peepholes: + c = ( + sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) + else: + c = ( + sigmoid(f + self._forget_bias) * c_prev + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + self._w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + transposed_proj_kernel = tf.transpose(self._proj_kernel) + m = math_ops.matmul(m, transposed_proj_kernel) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + c = self._tflite_wrapper.add_output( + c, tag="c", name="c", aggregate="last", index_override=1) + m = self._tflite_wrapper.add_output( + m, tag="m", name="m", index_override=2, aggregate="stack") + + new_state = ( + tf.nn.rnn_cell.LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) + return m, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(TFLiteLSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca977518cb11db5f7ed33afa25ead5c02221a95 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py @@ -0,0 +1,226 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.tools import optimize_for_inference_lib + +# Number of steps to train model. +TRAIN_STEPS = 1 + +CONFIG = tf.ConfigProto(device_count={"GPU": 0}) + + +class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): + + def setUp(self): + tf.reset_default_graph() + # Import MNIST dataset + self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + + # Define constants + # Unrolled through 28 time steps + self.time_steps = 28 + # Rows of 28 pixels + self.n_input = 28 + # Learning rate for Adam optimizer + self.learning_rate = 0.001 + # MNIST is meant to be classified in 10 classes(0-9). + self.n_classes = 10 + # Batch size + self.batch_size = 16 + # Lstm Units. + self.num_units = 64 + + def buildLstmLayer(self): + return tf.nn.rnn_cell.MultiRNNCell([ + TFLiteLSTMCell( + self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), + TFLiteLSTMCell(self.num_units, num_proj=64, forget_bias=0, name="rnn2"), + TFLiteLSTMCell( + self.num_units // 2, + use_peepholes=True, + num_proj=64, + forget_bias=0, + name="rnn3"), + TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4") + ]) + + def buildModel(self, lstm_layer, is_dynamic_rnn, is_train): + # Weights and biases for output softmax layer. + out_weights = tf.Variable( + tf.random_normal([self.num_units, self.n_classes])) + out_bias = tf.Variable(tf.random_normal([self.n_classes])) + + # input image placeholder + x = tf.placeholder( + "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") + + # For dynamic_rnn, train with dynamic_rnn and inference with static_rnn. + # x is shaped [batch_size,time_steps,num_inputs] + if is_dynamic_rnn: + if is_train: + lstm_input = x + outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32") + outputs = tf.unstack(outputs, axis=1) + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + + # Compute logits by multiplying outputs[-1] of shape [batch_size,num_units] + # by the softmax layer's out_weight of shape [num_units,n_classes] + # plus out_bias + prediction = tf.matmul(outputs[-1], out_weights) + out_bias + output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS") + + return x, prediction, output_class + + def trainModel(self, x, prediction, output_class, sess): + # input label placeholder + y = tf.placeholder("float", [None, self.n_classes]) + # Loss function + loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) + # Optimization + opt = tf.train.AdamOptimizer( + learning_rate=self.learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + sess.run(init) + for _ in range(TRAIN_STEPS): + batch_x, batch_y = self.mnist.train.next_batch( + batch_size=self.batch_size, shuffle=False) + + batch_x = batch_x.reshape((self.batch_size, self.time_steps, + self.n_input)) + sess.run(opt, feed_dict={x: batch_x, y: batch_y}) + + def saveAndRestoreModel(self, lstm_layer, sess, saver, is_dynamic_rnn): + model_dir = tempfile.mkdtemp() + saver.save(sess, model_dir) + + # Reset the graph. + tf.reset_default_graph() + x, prediction, output_class = self.buildModel( + lstm_layer, is_dynamic_rnn, is_train=False) + + new_sess = tf.Session(config=CONFIG) + saver = tf.train.Saver() + saver.restore(new_sess, model_dir) + return x, prediction, output_class, new_sess + + def getInferenceResult(self, x, output_class, sess): + b1, _ = self.mnist.train.next_batch(batch_size=1) + sample_input = np.reshape(b1, (1, self.time_steps, self.n_input)) + + expected_output = sess.run(output_class, feed_dict={x: sample_input}) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, [output_class.op.name]) + return sample_input, expected_output, frozen_graph + + def tfliteInvoke(self, graph, test_inputs, outputs): + tf.reset_default_graph() + # Turn the input into placeholder of shape 1 + tflite_input = tf.placeholder( + "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE") + tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) + with tf.Session() as sess: + curr = sess.graph_def + curr = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=curr) + + curr = optimize_for_inference_lib.optimize_for_inference( + curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], + [tf.float32.as_datatype_enum]) + + tflite = tf.contrib.lite.toco_convert( + curr, [tflite_input], [outputs], allow_custom_ops=False) + interpreter = tf.contrib.lite.Interpreter(model_content=tflite) + + try: + interpreter.allocate_tensors() + except ValueError: + assert False + + input_index = (interpreter.get_input_details()[0]["index"]) + interpreter.set_tensor(input_index, test_inputs) + interpreter.invoke() + output_index = (interpreter.get_output_details()[0]["index"]) + result = interpreter.get_tensor(output_index) + # Reset all variables so it will not pollute other inferences. + interpreter.reset_all_variables() + return result + + def testStaticRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + # Since we don't yet support OpHints for dynamic, we will load the model + # back in as a static model. This requires the variables to have the same + # names as if they were trained as a static. Thus, we get rid of while/rnn + # names. + variables_to_save = {} + for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): + op_name = i.name + if op_name.startswith("while/rnn/"): + op_name = op_name.split("while/rnn/")[1] + if op_name.endswith(":0"): + op_name = op_name.split(":0")[0] + variables_to_save[op_name] = i + saver = tf.train.Saver(variables_to_save) + + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc index 8442c4d46ca74f7645ffd68dc8e50bd46f7addb4..b1ebe4a804a971043d19b588f07ffc54b1d1aa38 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index aa42b495bdfb5823e4a94b004793837004b69429..942dbbbeae553ba55ea75b3257aca28b9b12eb77 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/experimental/micro/BUILD b/tensorflow/contrib/lite/experimental/micro/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df1036bc8b9cc84f4b63ae2a771e3aa8f8989060 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/BUILD @@ -0,0 +1,76 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_framework", + srcs = [ + "micro_error_reporter.cc", + "micro_interpreter.cc", + "micro_mutable_op_resolver.cc", + "simple_tensor_allocator.cc", + ], + hdrs = [ + "compatibility.h", + "micro_error_reporter.h", + "micro_interpreter.h", + "micro_mutable_op_resolver.h", + "simple_tensor_allocator.h", + ], + deps = [ + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +tflite_micro_cc_test( + name = "micro_error_reporter_test", + srcs = [ + "micro_error_reporter_test.cc", + ], + deps = [ + ":micro_framework", + ], +) + +tflite_micro_cc_test( + name = "micro_mutable_op_resolver_test", + srcs = [ + "micro_mutable_op_resolver_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "micro_interpreter_test", + srcs = [ + "micro_interpreter_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "simple_tensor_allocator_test", + srcs = [ + "simple_tensor_allocator_test.cc", + ], + deps = [ + ":micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/README.md b/tensorflow/contrib/lite/experimental/micro/README.md new file mode 100644 index 0000000000000000000000000000000000000000..414cafde4d489eac36f739f163033bf27f0fc818 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/README.md @@ -0,0 +1,114 @@ +# TensorFlow Lite for Microcontrollers + +This an experimental port of TensorFlow Lite aimed at micro controllers and other devices with only kilobytes of memory. It doesn't require any operating system support, any standard C or C++ libraries, or dynamic memory allocation, so it's designed to be portable even to 'bare metal' systems. The core runtime fits in 16KB on a Cortex M3, and with enough operators to run a speech keyword detection model, takes up a total of 22KB. + +The design goals are for the framework to be: + +- **Readable**: We want embedded software engineers to be able to understand what's required to run ML inference without having to study research papers. We've tried to keep the code base small, modular, and have reference implementations of all operations to help with this. + +- **Easy to modify**: We know that there are a lot of different platforms and requirements in the embedded world, and we don't expect to cover all of them in one framework. Instead, we're hoping that it can be a good starting point for developers to build on top of to meet their own needs. For example, we tried to make it easy to replace the implementations of key computational operators that are often crucial for performance, without having to touch the data flow and other runtime code. We want it to make more sense to use our workflow to handle things like model import and less-important operations, and customize the parts that matter, rather than having to reimplement everything in your own engine. + +- **Well-tested**: If you're modifying code, you need to know if your changes are correct. Having an easy way to test lets you develop much faster. To help there, we've written tests for all the components, and we've made sure that the tests can be run on almost any platform, with no dependencies apart from the ability to log text to a debug console somewhere. We also provide an easy way to run all the tests on-device as part of an automated test framework, and we use qemu/Renode emulation so that tests can be run even without physical devices present. + +- **Easy to integrate**: We want to be as open a system as possible, and use the best code available for each platform. To do that, we're going to rely on projects like [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to handle as much performance-critical code as possible. We know that there are an increasing number of options to accelerate neural networks on microcontrollers, so we're aiming to be a good host for deploying those hardware technologies too. + +- **Compatible**: We're using the same file schema, interpreter API, and kernel interface as regular TensorFlow Lite, so we leverage the large existing set of tools, documentation, and examples for the project. The biggest barrier to deploying ML models is getting them from a training environment into a form that's easy to run inference on, so we see reusing this rich ecosystem as being crucial to being easily usable. We also hope to integrate this experimental work back into the main codebase in the future. + +To meet those goals, we've made some tradeoffs: + +- **Simple C++**: To help with readability, our code is written in a modern version of C++, but we generally treat it as a "better C", rather relying on more complex features such as template meta-programming. As mentioned earlier, we avoid any use of dynamic memory allocation (new/delete) or the standard C/C++ libraries, so we believe this should still be fairly portable. It does mean that some older devices with C-only toolchains won't be supported, but we're hoping that the reference operator implementations (which are simple C-like functions) can still be useful in those cases. The interfaces are also designed to be C-only, so it should be possible to integrate the resulting library with pure C projects. + +- **Interpreted**: Code generation is a popular pattern for embedded code, because it gives standalone code that's easy to modify and step through, but we've chosen to go with an interpreted approach. In our internal microcontroller work we've found that using an extremely stripped-down interpreter with almost no dependencies gives us a lot of the same advantages, but is easier to maintain. For example, when new updates come out for the underlying library, you can just merge your local modifications in a single step, rather than having to regenerate new code and then patch in any changes you subsequently made. The coarse granularity of the interpreted primitives means that each operation call typically takes hundreds of thousands of instruction cycles at least, so we don't see noticeable performance gains from avoiding what's essentially a single switch statement at the interpreter level to call each operation. We're still working on improving the packaging though, for example we're considering having the ability to snapshot all the source files and headers used for a particular model, being able to compile the code and data together as a library, and then access it through a minimal set of C interface calls which hide the underlying complexity. + +- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version. + +- **Code Duplication**: Some of the code in this prototype largely duplicates the logic in other parts of the TensorFlow Lite code base, for example the operator wrappers. We've tried to keep share as much as we can between the two interpreters, but there are some assumptions built into the original runtime that make this difficult. We'll be working on modularizing the main interpreter so that we can move to an entirely shared system. + +This initial preview release is designed to get early feedback, and is not intended to be a final product. It only includes enough operations to run a simple keyword recognition model, and the implementations are not optimized. We're hoping this will be a good way to get feedback and collaborate to improve the framework. + +## Getting Started + +Building requires a Linux or OS X machine. + + - Open a terminal + - Download the TensorFlow source with `git clone https://github.com/tensorflow` + - Enter the source root directory by running `cd tensorflow` + - Download the dependencies by running `tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes + - Build and test the library with `make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test` + +You should see a series of compilation steps, followed by "~~~ALL TESTS PASSED~~~" for the various tests of the code that it will run. If there's an error, you should get an informative message from make about what went wrong. + +These tests are all built as simple binaries with few dependencies, so you can run them manually. For example, here's how to run the depthwise convolution test, and its output: + +``` +tensorflow/contrib/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test + +Testing SimpleTest +Testing SimpleTestQuantized +Testing SimpleTestRelu +Testing SimpleTestReluQuantized +4/4 tests passed +~ALL TESTS PASSED~~~ +``` + +Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this: + +``` +... +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { +... +} +... +TF_LITE_MICRO_TESTS_END +``` + +These macros work a lot like [the Google test framework](https://github.com/google/googletest), but they don't require any dependencies and just write results to stderr, rather than aborting the program. If all the tests pass, then "~~~ALL TESTS PASSED~~~" is output, and the test harness that runs the binary during the make process knows that everything ran correctly. If there's an error, the lack of the expected string lets the harness know that the test failed. + +So, why are we running tests in this complicated way? So far, we've been building binaries that run locally on the Mac OS or Linux machine you're building on, but this approach becomes important when we're targeting simple micro controller devices. + +## Building for the "Blue Pill" STM32F103 + +The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/googletest) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips. + +It's fairly easy to [buy and wire up a physical board](https://github.com/google/stm32_bare_lib#wiring-up-your-blue-pill), but even if you don't have an actual device, the [Renode project](https://renode.io/) makes it easy to run a faithful emulation on your desktop machine. You'll need [Docker](https://www.docker.com/) installed, but once you have that set up, try running the following command: + +`make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test` + +You should see a similar set of outputs as you did in the previous section, with the addition of some extra Docker logging messages. These are because we're using Docker to run the Renode micro controller emulation tool, and the tests themselves are being run on a simulated STM32F103 device. The communication channels between an embedded device and the host are quite limited, so the test harness looks at the output of the debug log to see if tests have passed, just as it did in the previous section. This makes it a very flexible way to run cross-platform tests, even when a platform has no operating system facilities, as long as it can output debugging text logs. + +To understand what's happening here, try running the same depthwise convolution test, but through the emulated device test harness, with the following command: + +``` +tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh \ +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test + +``` + +You should see output that looks something like this: + +``` +Sending build context to Docker daemon 21.5kB +Step 1/2 : FROM antmicro/renode:latest + ---> 1b670a243e8f +Step 2/2 : LABEL maintainer="Pete Warden " + ---> Using cache + ---> 3afcd410846d +Successfully built 3afcd410846d +Successfully tagged renode_bluepill:latest +LOGS: +... +03:27:32.4340 [INFO] machine-0: Machine started. +03:27:32.4790 [DEBUG] cpu.uartSemihosting: [+0.22s host +0s virt 0s virt from start] Testing SimpleTest +03:27:32.4812 [DEBUG] cpu.uartSemihosting: [+2.21ms host +0s virt 0s virt from start] Testing SimpleTestQuantized +03:27:32.4833 [DEBUG] cpu.uartSemihosting: [+2.14ms host +0s virt 0s virt from start] Testing SimpleTestRelu +03:27:32.4834 [DEBUG] cpu.uartSemihosting: [+0.18ms host +0s virt 0s virt from start] Testing SimpleTestReluQuantized +03:27:32.4838 [DEBUG] cpu.uartSemihosting: [+0.4ms host +0s virt 0s virt from start] 4/4 tests passed +03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+41µs host +0s virt 0s virt from start] ~~~ALL TESTS PASSED~~~ +03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+5µs host +0s virt 0s virt from start] +... +tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test: PASS +``` + +There's a lot of output here, but you should be able to see that the same tests that were covered when we ran locally on the development machine show up in the debug logs here, along with the magic string "~~~ALL TESTS PASSED~~~". This is the exact same code as before, just compiled and run on the STM32F103 rather than your desktop. We hope that the simplicity of this testing approach will help make adding support for new platforms as easy as possible. diff --git a/tensorflow/contrib/lite/experimental/micro/compatibility.h b/tensorflow/contrib/lite/experimental/micro/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..4f0fd9f3120a5db74cdfb84e7b17a0f3656520bc --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/compatibility.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ + +// C++ will automatically create class-specific delete operators for virtual +// objects, which by default call the global delete function. For embedded +// applications we want to avoid this, and won't be calling new/delete on these +// objects, so we need to override the default implementation with one that does +// nothing to avoid linking in ::delete(). +// This macro needs to be included in all subclasses of a virtual base class in +// the private section. +#ifdef TF_LITE_STATIC_MEMORY +#define TF_LITE_REMOVE_VIRTUAL_DELETE \ + void operator delete(void* p) {} +#else +#define TF_LITE_REMOVE_VIRTUAL_DELETE +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dad58b6c1cc818d3ae68dd4fdf5ec47315e1b5cc --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD @@ -0,0 +1,31 @@ +# Description: +# TensorFlow Lite microcontroller example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +tflite_micro_cc_test( + name = "micro_speech_test", + srcs = [ + "micro_speech_test.cc", + "tiny_conv_model_data.cc", + "tiny_conv_model_data.h", + ], + tags = [ + "nomsan", + ], + deps = [ + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/kernels:all_ops_resolver", + "//tensorflow/contrib/lite/experimental/micro/kernels:micro_ops", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..86cd056a7216aa57126be3f6e660a7dcee0c6c44 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestInvoke) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + + const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data); + if (model->version() != TFLITE_SCHEMA_VERSION) { + error_reporter->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model->version(), TFLITE_SCHEMA_VERSION); + } + tflite::ops::micro::AllOpsResolver resolver; + + const int tensor_arena_size = 10 * 1024; + uint8_t tensor_arena[tensor_arena_size]; + tflite::SimpleTensorAllocator tensor_allocator(tensor_arena, + tensor_arena_size); + + tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator, + error_reporter); + TfLiteStatus invoke_status = interpreter.Invoke(); + if (invoke_status != kTfLiteOk) { + error_reporter->Report("Invoke failed\n"); + } + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status); + + error_reporter->Report("Ran successfully\n"); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1f9e0e21994b0a79241690e533e4edc8bfe5565 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc @@ -0,0 +1,1672 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Automatically created from a TensorFlow Lite flatbuffer using the command: +// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc + +#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h" + +const unsigned char g_tiny_conv_model_data[] = { + 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x4d, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xf4, 0x47, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, + 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74, + 0x65, 0x64, 0x2e, 0x00, 0x09, 0x00, 0x00, 0x00, 0xd4, 0x47, 0x00, 0x00, + 0x04, 0x03, 0x00, 0x00, 0xfc, 0x02, 0x00, 0x00, 0xf4, 0x02, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xb3, 0xff, 0xff, + 0x16, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xd7, 0x02, 0x00, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe8, 0xb3, 0xff, 0xff, + 0x46, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0xab, 0x00, 0x00, 0x00, 0x1e, 0xff, 0xff, 0xff, 0xed, 0xff, 0xff, 0xff, + 0x4a, 0x00, 0x00, 0x00, 0x62, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x80, 0x02, 0x00, 0x00, 0xce, 0xad, 0xaf, 0x3c, 0xc8, 0xe9, 0xb0, 0x83, + 0xa1, 0xbf, 0xb2, 0xb1, 0xab, 0xd0, 0xa7, 0x53, 0xa5, 0xe9, 0xb5, 0xac, + 0xa2, 0xd3, 0xc4, 0x9e, 0x8b, 0xb2, 0x64, 0xb3, 0x9d, 0xa2, 0xae, 0xa6, + 0xd5, 0xbe, 0x43, 0x9f, 0x9c, 0x54, 0xb5, 0xa8, 0x49, 0x78, 0x86, 0xa2, + 0xa3, 0x55, 0x35, 0x96, 0x3d, 0x7f, 0xe2, 0xb5, 0xb0, 0x47, 0x28, 0xa9, + 0x9d, 0xbb, 0xd6, 0xff, 0xb7, 0x79, 0x63, 0xb5, 0xaf, 0xa7, 0xab, 0x7e, + 0xbc, 0xc7, 0xa0, 0xc3, 0xb1, 0xb6, 0xb2, 0xa1, 0xc2, 0xbb, 0x79, 0x57, + 0xbe, 0xc1, 0xb7, 0xb0, 0x6b, 0xb7, 0xa5, 0x75, 0x97, 0xb8, 0xe7, 0xac, + 0xad, 0x7e, 0xb1, 0x9b, 0xc3, 0xba, 0x6b, 0xa2, 0x7f, 0x58, 0xb9, 0x7a, + 0x4c, 0x91, 0x74, 0x9e, 0xa7, 0x3d, 0xc2, 0x94, 0x75, 0xa1, 0xa4, 0xac, + 0xab, 0x45, 0x2e, 0xb4, 0xb6, 0xbf, 0xc1, 0xdb, 0xaf, 0x6c, 0x67, 0xb1, + 0xa9, 0xa6, 0xa8, 0xca, 0xc2, 0xc4, 0xb9, 0xbf, 0xb4, 0xb9, 0xaa, 0x9d, + 0x9f, 0xb9, 0xb2, 0x71, 0xb2, 0xca, 0xbe, 0xaf, 0x5f, 0xbc, 0xa0, 0x5b, + 0xa8, 0xb4, 0xa4, 0xa8, 0xd8, 0x69, 0xb7, 0x8a, 0xbc, 0xb8, 0xaf, 0x9c, + 0x7c, 0x5d, 0xb3, 0x6b, 0x49, 0x95, 0x64, 0xa0, 0xa2, 0x49, 0xcb, 0x87, + 0xa5, 0xb5, 0xa1, 0xb2, 0xa3, 0x40, 0x6d, 0x9f, 0xc5, 0xb6, 0xbb, 0xd4, + 0x9c, 0x6d, 0x69, 0xa9, 0xa8, 0x91, 0xad, 0xb8, 0xd2, 0xc6, 0xaf, 0xb8, + 0xac, 0xa9, 0xa2, 0xa7, 0x60, 0xa6, 0xa1, 0xc9, 0xb8, 0xd6, 0xcf, 0xb1, + 0x56, 0xb4, 0xac, 0x40, 0xae, 0xbd, 0xbf, 0xa2, 0x54, 0x72, 0x9b, 0x8c, + 0xc2, 0xb5, 0xc2, 0x9b, 0x64, 0x6d, 0xb4, 0x62, 0x4e, 0x9b, 0x6c, 0xa6, + 0x8f, 0x4c, 0xca, 0x95, 0xb6, 0xbf, 0x92, 0xae, 0x9c, 0x49, 0xae, 0xb2, + 0xc0, 0xb6, 0xbc, 0xd1, 0xa4, 0x7b, 0x64, 0xa0, 0xa6, 0x81, 0xac, 0xa6, + 0xbd, 0xc8, 0xbc, 0xae, 0xaa, 0x9e, 0x61, 0xb1, 0x57, 0xac, 0xbf, 0xbf, + 0xbb, 0xe0, 0xa6, 0xae, 0x47, 0xc9, 0xbc, 0x57, 0xb0, 0xb5, 0xc7, 0x98, + 0xf4, 0x93, 0xb6, 0x70, 0xc3, 0xb3, 0xca, 0xab, 0x77, 0x9a, 0xac, 0x45, + 0x5c, 0x9e, 0x9a, 0xa9, 0x9b, 0x35, 0xc0, 0x6f, 0xc6, 0xc7, 0x91, 0xb4, + 0xa8, 0x3c, 0xce, 0xb8, 0xad, 0xb9, 0xb5, 0xdd, 0x9c, 0x6d, 0xbf, 0x91, + 0xb2, 0x7d, 0xa0, 0xaf, 0x9f, 0xbd, 0xb9, 0xcf, 0x9b, 0x5d, 0x3f, 0xac, + 0x64, 0xae, 0xaf, 0xb8, 0xbc, 0xb8, 0x86, 0xb5, 0x36, 0xcf, 0xb4, 0xa9, + 0xad, 0xcd, 0xdb, 0xa4, 0x68, 0xa6, 0xa4, 0x67, 0xc8, 0xb7, 0xe5, 0xa4, + 0x76, 0xb8, 0xa8, 0x28, 0x6b, 0xa5, 0xba, 0xad, 0x9f, 0x3a, 0xa5, 0x42, + 0xc5, 0xb0, 0x88, 0xad, 0xa5, 0x4d, 0xea, 0x8a, 0xb8, 0xb5, 0xb3, 0xd9, + 0xa0, 0x77, 0xbb, 0x92, 0x9e, 0x80, 0xbd, 0xbd, 0x6d, 0xcc, 0xab, 0x99, + 0x88, 0x58, 0x4d, 0xb0, 0x6c, 0xbc, 0x96, 0xbd, 0xae, 0xab, 0x5b, 0xac, + 0x2f, 0xc3, 0x9a, 0xbe, 0xac, 0xb3, 0x84, 0x9b, 0xe3, 0xaf, 0x95, 0x6b, + 0xc2, 0xb5, 0xca, 0xb7, 0x4e, 0xbc, 0x9d, 0x24, 0x75, 0xa9, 0xd2, 0xae, + 0xa0, 0x2b, 0x90, 0x34, 0xd1, 0xb5, 0x96, 0xae, 0xaa, 0x4d, 0xc1, 0xa3, + 0xb1, 0xb4, 0xaa, 0xd2, 0x9c, 0x7d, 0xc0, 0x91, 0x91, 0x7a, 0xb8, 0x83, + 0x44, 0xcb, 0xaf, 0x9b, 0x6b, 0x5b, 0x75, 0xb2, 0x62, 0xb6, 0xaa, 0xcb, + 0x99, 0xa8, 0x63, 0xae, 0x24, 0xc7, 0x8a, 0xbe, 0xa9, 0xb6, 0xa0, 0xa1, + 0x41, 0xac, 0x84, 0xb5, 0xb9, 0xb3, 0x9b, 0xad, 0x77, 0xbf, 0xa8, 0x7e, + 0x82, 0xb9, 0xbe, 0xaa, 0xa3, 0x47, 0x6d, 0xb5, 0xc3, 0xb1, 0xbf, 0xa7, + 0xb1, 0x57, 0x75, 0xb5, 0xb0, 0xb6, 0xb9, 0xce, 0xa4, 0x86, 0xb0, 0xa4, + 0x98, 0x80, 0xc5, 0x3e, 0x90, 0xca, 0x9b, 0xa2, 0x5a, 0x50, 0xc5, 0xa5, + 0xad, 0xc1, 0x9c, 0x91, 0x83, 0x8f, 0x21, 0xab, 0xac, 0xba, 0x70, 0xb4, + 0xae, 0x85, 0x7e, 0xa7, 0xbd, 0xba, 0x7c, 0xb2, 0xb5, 0xb2, 0x7e, 0xb3, + 0xc3, 0xcd, 0x82, 0xac, 0x9b, 0xb3, 0xa6, 0xb0, 0xbc, 0x6f, 0x52, 0xb9, + 0xbf, 0xb1, 0xa6, 0xa4, 0xc1, 0x7a, 0x90, 0xc0, 0xae, 0xab, 0x94, 0xd8, + 0xab, 0xa4, 0x98, 0xbb, 0x8b, 0x86, 0x94, 0x01, 0xad, 0xe7, 0xb1, 0x9b, + 0x57, 0x48, 0xc1, 0x88, 0xbf, 0xcc, 0xb4, 0x4b, 0x62, 0x8b, 0x48, 0xa7, + 0xbe, 0xe1, 0x80, 0xa6, 0xb3, 0x64, 0xaa, 0xa4, 0xcf, 0xba, 0x6d, 0xa6, + 0xb8, 0xa0, 0x8f, 0xb3, 0xce, 0xc3, 0x87, 0xb2, 0xa0, 0xc0, 0x78, 0xb0, + 0xb9, 0xaa, 0x40, 0xb8, 0xd8, 0xa3, 0x9a, 0xaa, 0xcc, 0xa2, 0x9f, 0xb9, + 0xbe, 0xc2, 0x89, 0xd6, 0xc6, 0x9c, 0xa3, 0xc7, 0x94, 0xb6, 0xff, 0xff, + 0x98, 0xb6, 0xff, 0xff, 0xf6, 0xb6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0xc0, 0x44, 0x00, 0x00, 0x4a, 0x4d, 0x59, 0x60, 0x5a, 0x45, 0x3d, 0x50, + 0x4a, 0x43, 0x3d, 0x59, 0x3e, 0x49, 0x4a, 0x59, 0x45, 0x44, 0x41, 0x5d, + 0x50, 0x2f, 0x4e, 0x34, 0x46, 0x48, 0x41, 0x4a, 0x4c, 0x3b, 0x4b, 0x3e, + 0x49, 0x49, 0x43, 0x4b, 0x3e, 0x49, 0x47, 0x41, 0x3e, 0x4a, 0x46, 0x43, + 0x41, 0x43, 0x47, 0x49, 0x4a, 0x4c, 0x46, 0x58, 0x3f, 0x4c, 0x4b, 0x4c, + 0x4d, 0x4b, 0x45, 0x52, 0x45, 0x42, 0x52, 0x52, 0x48, 0x40, 0x46, 0x5f, + 0x4c, 0x41, 0x47, 0x48, 0x48, 0x4c, 0x43, 0x61, 0x50, 0x4b, 0x49, 0x49, + 0x46, 0x3f, 0x40, 0x67, 0x40, 0x4d, 0x45, 0x40, 0x40, 0x45, 0x47, 0x56, + 0x44, 0x3a, 0x4a, 0x4c, 0x52, 0x48, 0x46, 0x50, 0x4b, 0x44, 0x51, 0x45, + 0x40, 0x45, 0x45, 0x48, 0x4e, 0x4e, 0x43, 0x48, 0x44, 0x4b, 0x45, 0x4a, + 0x53, 0x45, 0x4a, 0x4b, 0x3f, 0x43, 0x45, 0x53, 0x4d, 0x43, 0x46, 0x3f, + 0x47, 0x4e, 0x51, 0x50, 0x48, 0x4f, 0x4f, 0x4a, 0x4a, 0x4e, 0x45, 0x4e, + 0x46, 0x41, 0x4a, 0x46, 0x45, 0x47, 0x45, 0x4b, 0x50, 0x4c, 0x46, 0x45, + 0x41, 0x47, 0x41, 0x47, 0x46, 0x4f, 0x3f, 0x4f, 0x4a, 0x51, 0x4f, 0x53, + 0x54, 0x48, 0x51, 0x43, 0x4b, 0x48, 0x4d, 0x46, 0x48, 0x4f, 0x49, 0x44, + 0x43, 0x53, 0x50, 0x59, 0x56, 0x3d, 0x45, 0x44, 0x48, 0x38, 0x3b, 0x5f, + 0x39, 0x43, 0x43, 0x52, 0x46, 0x3e, 0x43, 0x58, 0x43, 0x1e, 0x50, 0x3c, + 0x46, 0x4b, 0x46, 0x50, 0x3c, 0x37, 0x4c, 0x47, 0x47, 0x4b, 0x47, 0x54, + 0x43, 0x3e, 0x47, 0x4f, 0x4b, 0x41, 0x53, 0x50, 0x42, 0x46, 0x4f, 0x4b, + 0x4e, 0x3f, 0x49, 0x52, 0x4a, 0x4a, 0x49, 0x53, 0x52, 0x47, 0x52, 0x5a, + 0x40, 0x42, 0x4d, 0x4b, 0x50, 0x43, 0x49, 0x59, 0x47, 0x4c, 0x4d, 0x50, + 0x4e, 0x3c, 0x44, 0x61, 0x51, 0x49, 0x49, 0x46, 0x49, 0x47, 0x4b, 0x5a, + 0x45, 0x4b, 0x43, 0x40, 0x44, 0x52, 0x4d, 0x54, 0x49, 0x47, 0x44, 0x48, + 0x46, 0x48, 0x3e, 0x40, 0x45, 0x4f, 0x4d, 0x4b, 0x4c, 0x40, 0x3d, 0x40, + 0x3e, 0x48, 0x50, 0x4e, 0x4c, 0x42, 0x48, 0x4b, 0x3d, 0x48, 0x4b, 0x44, + 0x52, 0x4b, 0x49, 0x4f, 0x49, 0x3f, 0x47, 0x43, 0x4d, 0x3f, 0x53, 0x4e, + 0x4a, 0x4f, 0x4e, 0x4e, 0x53, 0x42, 0x46, 0x4c, 0x44, 0x4c, 0x46, 0x51, + 0x45, 0x48, 0x4a, 0x50, 0x47, 0x41, 0x45, 0x54, 0x4a, 0x44, 0x50, 0x49, + 0x48, 0x50, 0x51, 0x4b, 0x50, 0x4c, 0x4a, 0x49, 0x43, 0x47, 0x50, 0x4a, + 0x4d, 0x4c, 0x4e, 0x49, 0x42, 0x50, 0x52, 0x48, 0x45, 0x5a, 0x4e, 0x55, + 0x51, 0x3d, 0x3d, 0x4d, 0x42, 0x32, 0x36, 0x64, 0x39, 0x4c, 0x41, 0x48, + 0x44, 0x35, 0x43, 0x56, 0x47, 0x1e, 0x4b, 0x3e, 0x47, 0x3f, 0x43, 0x52, + 0x51, 0x34, 0x41, 0x4d, 0x3e, 0x41, 0x41, 0x48, 0x3c, 0x4b, 0x45, 0x3b, + 0x40, 0x43, 0x4c, 0x46, 0x46, 0x47, 0x3e, 0x4f, 0x4b, 0x48, 0x42, 0x47, + 0x4e, 0x3e, 0x49, 0x47, 0x43, 0x43, 0x4e, 0x52, 0x51, 0x45, 0x3f, 0x54, + 0x46, 0x44, 0x48, 0x5d, 0x3e, 0x4a, 0x47, 0x52, 0x53, 0x3a, 0x4f, 0x5d, + 0x41, 0x4c, 0x48, 0x51, 0x43, 0x4b, 0x4b, 0x67, 0x48, 0x4b, 0x45, 0x4d, + 0x4b, 0x43, 0x4a, 0x54, 0x4c, 0x46, 0x43, 0x4a, 0x4d, 0x43, 0x4c, 0x47, + 0x4a, 0x48, 0x4d, 0x42, 0x4d, 0x48, 0x3f, 0x43, 0x4c, 0x44, 0x4e, 0x4c, + 0x40, 0x45, 0x4b, 0x48, 0x47, 0x47, 0x3e, 0x4c, 0x52, 0x41, 0x44, 0x4e, + 0x4d, 0x44, 0x49, 0x4d, 0x3d, 0x45, 0x48, 0x4f, 0x4c, 0x4a, 0x55, 0x51, + 0x4d, 0x4c, 0x45, 0x4e, 0x46, 0x45, 0x44, 0x49, 0x4e, 0x44, 0x40, 0x48, + 0x49, 0x44, 0x53, 0x51, 0x42, 0x41, 0x51, 0x49, 0x51, 0x45, 0x51, 0x3f, + 0x4b, 0x3f, 0x52, 0x3c, 0x50, 0x4d, 0x4f, 0x4b, 0x44, 0x4f, 0x40, 0x52, + 0x49, 0x4a, 0x50, 0x3f, 0x3d, 0x54, 0x4c, 0x53, 0x52, 0x45, 0x41, 0x43, + 0x47, 0x2d, 0x40, 0x63, 0x3a, 0x51, 0x43, 0x4e, 0x40, 0x2b, 0x36, 0x5b, + 0x4b, 0x12, 0x4d, 0x35, 0x4b, 0x3f, 0x44, 0x4a, 0x46, 0x31, 0x54, 0x48, + 0x43, 0x42, 0x3d, 0x51, 0x41, 0x45, 0x49, 0x4b, 0x47, 0x49, 0x3d, 0x3e, + 0x46, 0x3d, 0x4d, 0x48, 0x3d, 0x45, 0x48, 0x4b, 0x49, 0x52, 0x44, 0x4c, + 0x45, 0x44, 0x45, 0x49, 0x50, 0x48, 0x45, 0x46, 0x45, 0x44, 0x52, 0x55, + 0x46, 0x45, 0x4b, 0x3d, 0x42, 0x4a, 0x3e, 0x57, 0x48, 0x4b, 0x3c, 0x42, + 0x4a, 0x46, 0x47, 0x6c, 0x54, 0x4b, 0x41, 0x49, 0x49, 0x50, 0x43, 0x56, + 0x44, 0x43, 0x4d, 0x3e, 0x44, 0x41, 0x47, 0x40, 0x4a, 0x4b, 0x4d, 0x4d, + 0x3e, 0x46, 0x45, 0x47, 0x3e, 0x42, 0x4a, 0x45, 0x49, 0x3d, 0x3f, 0x43, + 0x40, 0x44, 0x47, 0x4a, 0x45, 0x4d, 0x4b, 0x4c, 0x43, 0x40, 0x3d, 0x3e, + 0x4c, 0x4c, 0x42, 0x4d, 0x48, 0x4d, 0x49, 0x42, 0x51, 0x51, 0x4c, 0x4b, + 0x53, 0x4f, 0x48, 0x4d, 0x40, 0x46, 0x45, 0x4b, 0x47, 0x47, 0x4b, 0x46, + 0x54, 0x42, 0x42, 0x46, 0x46, 0x4a, 0x4c, 0x55, 0x3f, 0x3c, 0x52, 0x4b, + 0x4b, 0x4d, 0x4e, 0x48, 0x53, 0x4c, 0x4b, 0x42, 0x52, 0x54, 0x50, 0x4b, + 0x40, 0x5f, 0x58, 0x53, 0x50, 0x42, 0x35, 0x48, 0x39, 0x24, 0x3c, 0x5e, + 0x41, 0x50, 0x3c, 0x51, 0x42, 0x26, 0x42, 0x56, 0x41, 0x0c, 0x3e, 0x3d, + 0x48, 0x3e, 0x50, 0x4b, 0x3a, 0x2c, 0x43, 0x3d, 0x48, 0x3e, 0x43, 0x48, + 0x4c, 0x3f, 0x4a, 0x3e, 0x51, 0x4a, 0x4f, 0x40, 0x47, 0x43, 0x50, 0x4c, + 0x43, 0x4d, 0x3f, 0x45, 0x4d, 0x3e, 0x4c, 0x44, 0x51, 0x47, 0x4b, 0x51, + 0x45, 0x49, 0x44, 0x3f, 0x46, 0x46, 0x46, 0x57, 0x49, 0x4c, 0x49, 0x4e, + 0x47, 0x4c, 0x47, 0x5e, 0x43, 0x46, 0x45, 0x4b, 0x52, 0x49, 0x45, 0x5f, + 0x47, 0x41, 0x46, 0x43, 0x4f, 0x3b, 0x43, 0x51, 0x46, 0x53, 0x4a, 0x4e, + 0x4b, 0x43, 0x4e, 0x40, 0x48, 0x49, 0x46, 0x3f, 0x48, 0x50, 0x4b, 0x41, + 0x4a, 0x47, 0x4b, 0x3d, 0x46, 0x49, 0x4b, 0x43, 0x43, 0x42, 0x3e, 0x47, + 0x47, 0x4a, 0x45, 0x46, 0x51, 0x48, 0x51, 0x4e, 0x3f, 0x50, 0x44, 0x4b, + 0x4d, 0x4e, 0x44, 0x4d, 0x3d, 0x49, 0x4a, 0x4e, 0x42, 0x51, 0x43, 0x42, + 0x46, 0x3e, 0x48, 0x4b, 0x4f, 0x50, 0x3d, 0x48, 0x4c, 0x4f, 0x46, 0x44, + 0x44, 0x48, 0x42, 0x4b, 0x48, 0x41, 0x43, 0x46, 0x4d, 0x49, 0x4f, 0x43, + 0x41, 0x44, 0x3f, 0x3d, 0x45, 0x4f, 0x45, 0x41, 0x40, 0x58, 0x4f, 0x54, + 0x5b, 0x4b, 0x3a, 0x47, 0x3d, 0x28, 0x3d, 0x57, 0x3e, 0x51, 0x3f, 0x47, + 0x3f, 0x2e, 0x3e, 0x54, 0x4e, 0x0b, 0x41, 0x3d, 0x3b, 0x3d, 0x43, 0x47, + 0x47, 0x28, 0x4d, 0x43, 0x43, 0x3b, 0x4e, 0x4a, 0x4d, 0x42, 0x51, 0x46, + 0x4f, 0x3d, 0x4c, 0x3a, 0x49, 0x49, 0x4a, 0x43, 0x42, 0x4b, 0x47, 0x42, + 0x42, 0x49, 0x3f, 0x4d, 0x46, 0x4a, 0x49, 0x4e, 0x42, 0x3c, 0x4a, 0x41, + 0x4c, 0x40, 0x4d, 0x5a, 0x49, 0x46, 0x51, 0x46, 0x4b, 0x4c, 0x46, 0x62, + 0x45, 0x42, 0x51, 0x4e, 0x4d, 0x3e, 0x4d, 0x5b, 0x4d, 0x43, 0x45, 0x50, + 0x4b, 0x40, 0x50, 0x53, 0x4f, 0x4f, 0x51, 0x53, 0x46, 0x41, 0x4e, 0x3a, + 0x4b, 0x47, 0x3f, 0x3e, 0x4d, 0x48, 0x53, 0x3f, 0x45, 0x42, 0x4c, 0x45, + 0x55, 0x4c, 0x4b, 0x39, 0x4a, 0x45, 0x48, 0x4d, 0x47, 0x40, 0x48, 0x4f, + 0x4d, 0x49, 0x3e, 0x41, 0x46, 0x4e, 0x40, 0x49, 0x4b, 0x47, 0x4c, 0x45, + 0x44, 0x51, 0x4f, 0x4b, 0x48, 0x49, 0x44, 0x41, 0x43, 0x46, 0x51, 0x45, + 0x40, 0x48, 0x4b, 0x42, 0x44, 0x4f, 0x53, 0x4d, 0x44, 0x46, 0x4e, 0x4c, + 0x48, 0x50, 0x41, 0x45, 0x42, 0x48, 0x4d, 0x4d, 0x47, 0x45, 0x41, 0x45, + 0x48, 0x58, 0x4e, 0x46, 0x43, 0x53, 0x57, 0x52, 0x5e, 0x42, 0x45, 0x4e, + 0x39, 0x24, 0x32, 0x56, 0x47, 0x56, 0x49, 0x52, 0x46, 0x26, 0x3a, 0x51, + 0x4b, 0x05, 0x3e, 0x43, 0x3f, 0x38, 0x4d, 0x4b, 0x4f, 0x27, 0x51, 0x46, + 0x47, 0x41, 0x4a, 0x47, 0x4a, 0x3e, 0x44, 0x51, 0x3f, 0x3a, 0x43, 0x46, + 0x4d, 0x49, 0x46, 0x52, 0x43, 0x48, 0x49, 0x3e, 0x47, 0x46, 0x4a, 0x4d, + 0x47, 0x46, 0x52, 0x50, 0x44, 0x48, 0x4c, 0x47, 0x45, 0x41, 0x49, 0x5b, + 0x4d, 0x4b, 0x47, 0x4c, 0x4a, 0x47, 0x45, 0x5b, 0x49, 0x46, 0x52, 0x47, + 0x47, 0x3d, 0x55, 0x59, 0x40, 0x4b, 0x3e, 0x50, 0x42, 0x43, 0x40, 0x4f, + 0x48, 0x3f, 0x47, 0x53, 0x4d, 0x44, 0x4e, 0x37, 0x4c, 0x43, 0x51, 0x4d, + 0x46, 0x4e, 0x40, 0x41, 0x52, 0x44, 0x43, 0x4a, 0x50, 0x48, 0x47, 0x42, + 0x48, 0x45, 0x50, 0x4d, 0x42, 0x52, 0x44, 0x43, 0x45, 0x43, 0x4c, 0x4d, + 0x44, 0x51, 0x47, 0x48, 0x51, 0x4f, 0x48, 0x45, 0x49, 0x4a, 0x3e, 0x43, + 0x4d, 0x4e, 0x4e, 0x46, 0x54, 0x4d, 0x49, 0x4d, 0x47, 0x46, 0x4b, 0x41, + 0x4a, 0x49, 0x44, 0x45, 0x4d, 0x3e, 0x53, 0x50, 0x47, 0x4d, 0x4e, 0x43, + 0x4f, 0x45, 0x4e, 0x4a, 0x47, 0x49, 0x4c, 0x4c, 0x4d, 0x54, 0x42, 0x4c, + 0x43, 0x5d, 0x59, 0x50, 0x5e, 0x4b, 0x44, 0x43, 0x3c, 0x25, 0x31, 0x5b, + 0x46, 0x5a, 0x50, 0x4d, 0x41, 0x2a, 0x41, 0x4f, 0x44, 0x00, 0x41, 0x3d, + 0x43, 0x4b, 0x47, 0x45, 0x4e, 0x2e, 0x44, 0x46, 0x53, 0x3d, 0x43, 0x41, + 0x44, 0x46, 0x49, 0x42, 0x45, 0x4f, 0x4d, 0x3a, 0x43, 0x3c, 0x47, 0x53, + 0x43, 0x4e, 0x3f, 0x41, 0x4d, 0x50, 0x4b, 0x4c, 0x51, 0x47, 0x53, 0x4f, + 0x45, 0x4a, 0x44, 0x45, 0x41, 0x46, 0x47, 0x50, 0x51, 0x3f, 0x3e, 0x41, + 0x48, 0x45, 0x46, 0x5d, 0x45, 0x4a, 0x4c, 0x46, 0x4a, 0x49, 0x50, 0x51, + 0x51, 0x4c, 0x4f, 0x47, 0x47, 0x42, 0x45, 0x47, 0x4e, 0x48, 0x46, 0x40, + 0x45, 0x46, 0x4d, 0x3b, 0x4d, 0x52, 0x4c, 0x51, 0x49, 0x51, 0x47, 0x3d, + 0x4d, 0x42, 0x4f, 0x4e, 0x43, 0x43, 0x45, 0x3a, 0x42, 0x50, 0x4c, 0x4a, + 0x41, 0x53, 0x4c, 0x45, 0x51, 0x3f, 0x54, 0x43, 0x4b, 0x54, 0x56, 0x4d, + 0x4f, 0x4a, 0x50, 0x4b, 0x44, 0x45, 0x4f, 0x4f, 0x47, 0x3e, 0x50, 0x4f, + 0x4b, 0x48, 0x4d, 0x49, 0x55, 0x4d, 0x45, 0x4d, 0x4a, 0x53, 0x43, 0x46, + 0x4c, 0x45, 0x41, 0x46, 0x49, 0x49, 0x4f, 0x4b, 0x49, 0x50, 0x52, 0x49, + 0x41, 0x54, 0x44, 0x4c, 0x44, 0x63, 0x4a, 0x49, 0x40, 0x59, 0x52, 0x52, + 0x59, 0x3f, 0x3e, 0x3e, 0x40, 0x25, 0x3c, 0x5c, 0x4f, 0x57, 0x44, 0x50, + 0x41, 0x2a, 0x48, 0x4f, 0x43, 0x08, 0x47, 0x43, 0x49, 0x48, 0x4d, 0x49, + 0x46, 0x2b, 0x48, 0x44, 0x4e, 0x47, 0x47, 0x43, 0x44, 0x3e, 0x4a, 0x52, + 0x3f, 0x4a, 0x53, 0x42, 0x49, 0x47, 0x4c, 0x50, 0x43, 0x46, 0x46, 0x3c, + 0x4c, 0x47, 0x4e, 0x4d, 0x42, 0x41, 0x53, 0x52, 0x4f, 0x40, 0x54, 0x50, + 0x46, 0x43, 0x50, 0x56, 0x51, 0x48, 0x48, 0x48, 0x49, 0x39, 0x47, 0x5e, + 0x4e, 0x4b, 0x4f, 0x4e, 0x43, 0x45, 0x42, 0x58, 0x4a, 0x3b, 0x48, 0x4d, + 0x43, 0x3e, 0x4b, 0x43, 0x3c, 0x45, 0x46, 0x4b, 0x42, 0x42, 0x4e, 0x3d, + 0x4b, 0x4e, 0x51, 0x52, 0x48, 0x3e, 0x4b, 0x3f, 0x4c, 0x4a, 0x4b, 0x4c, + 0x46, 0x48, 0x3e, 0x48, 0x47, 0x4d, 0x4a, 0x46, 0x49, 0x4d, 0x4a, 0x48, + 0x50, 0x4b, 0x40, 0x48, 0x4b, 0x52, 0x46, 0x50, 0x4f, 0x3e, 0x42, 0x44, + 0x44, 0x42, 0x43, 0x49, 0x4f, 0x4f, 0x46, 0x42, 0x4a, 0x54, 0x42, 0x48, + 0x50, 0x4f, 0x4f, 0x4c, 0x4c, 0x47, 0x52, 0x49, 0x4c, 0x45, 0x4a, 0x4d, + 0x4a, 0x41, 0x47, 0x4a, 0x4d, 0x4a, 0x4c, 0x46, 0x51, 0x44, 0x4b, 0x49, + 0x53, 0x5e, 0x45, 0x4a, 0x3b, 0x57, 0x5a, 0x4c, 0x59, 0x43, 0x3e, 0x4a, + 0x3e, 0x20, 0x36, 0x5d, 0x47, 0x5b, 0x3f, 0x55, 0x3e, 0x24, 0x41, 0x52, + 0x3f, 0x01, 0x49, 0x41, 0x40, 0x45, 0x42, 0x46, 0x49, 0x2a, 0x47, 0x40, + 0x44, 0x3f, 0x42, 0x47, 0x4e, 0x42, 0x4b, 0x3d, 0x45, 0x4c, 0x47, 0x3d, + 0x4c, 0x44, 0x48, 0x43, 0x43, 0x41, 0x4a, 0x3d, 0x48, 0x4b, 0x46, 0x4e, + 0x4c, 0x45, 0x48, 0x4d, 0x54, 0x4d, 0x3e, 0x46, 0x3e, 0x47, 0x44, 0x4e, + 0x48, 0x49, 0x53, 0x4b, 0x41, 0x45, 0x4c, 0x57, 0x52, 0x4e, 0x40, 0x48, + 0x4d, 0x43, 0x44, 0x5a, 0x4a, 0x4c, 0x48, 0x4d, 0x3f, 0x52, 0x41, 0x50, + 0x4a, 0x47, 0x3e, 0x43, 0x4c, 0x42, 0x48, 0x3e, 0x4f, 0x4b, 0x41, 0x43, + 0x49, 0x40, 0x43, 0x36, 0x3f, 0x4b, 0x49, 0x49, 0x51, 0x43, 0x48, 0x40, + 0x4c, 0x51, 0x4d, 0x4a, 0x49, 0x3f, 0x4b, 0x3d, 0x4f, 0x4b, 0x43, 0x4d, + 0x46, 0x40, 0x46, 0x4d, 0x49, 0x48, 0x4d, 0x4c, 0x52, 0x4c, 0x49, 0x4f, + 0x53, 0x40, 0x49, 0x53, 0x47, 0x43, 0x4c, 0x45, 0x42, 0x48, 0x42, 0x4e, + 0x49, 0x43, 0x42, 0x40, 0x4f, 0x46, 0x50, 0x47, 0x51, 0x4a, 0x52, 0x45, + 0x4c, 0x51, 0x48, 0x47, 0x40, 0x41, 0x52, 0x4f, 0x41, 0x5a, 0x53, 0x47, + 0x42, 0x5f, 0x55, 0x4f, 0x53, 0x3e, 0x41, 0x49, 0x3d, 0x20, 0x3f, 0x54, + 0x42, 0x5b, 0x49, 0x4d, 0x3d, 0x22, 0x3e, 0x48, 0x41, 0x01, 0x4c, 0x3d, + 0x43, 0x4a, 0x46, 0x43, 0x4f, 0x2b, 0x49, 0x46, 0x47, 0x4a, 0x51, 0x3d, + 0x4b, 0x44, 0x49, 0x41, 0x47, 0x47, 0x45, 0x3a, 0x44, 0x42, 0x40, 0x52, + 0x46, 0x51, 0x4a, 0x41, 0x4a, 0x52, 0x44, 0x52, 0x4a, 0x40, 0x46, 0x45, + 0x52, 0x4c, 0x4e, 0x42, 0x42, 0x48, 0x40, 0x4f, 0x4b, 0x4f, 0x51, 0x4c, + 0x4e, 0x48, 0x4a, 0x5a, 0x46, 0x3d, 0x41, 0x50, 0x52, 0x4c, 0x44, 0x53, + 0x4b, 0x4d, 0x4f, 0x49, 0x47, 0x4c, 0x48, 0x45, 0x48, 0x4a, 0x44, 0x4e, + 0x4c, 0x40, 0x4d, 0x35, 0x40, 0x49, 0x4a, 0x51, 0x49, 0x4a, 0x46, 0x36, + 0x46, 0x47, 0x4a, 0x4c, 0x40, 0x4e, 0x42, 0x38, 0x48, 0x45, 0x42, 0x49, + 0x54, 0x4c, 0x3f, 0x49, 0x4c, 0x39, 0x47, 0x45, 0x4e, 0x4a, 0x42, 0x44, + 0x4b, 0x53, 0x43, 0x40, 0x46, 0x51, 0x3d, 0x50, 0x4b, 0x43, 0x4a, 0x4c, + 0x55, 0x54, 0x4a, 0x43, 0x48, 0x40, 0x44, 0x3f, 0x47, 0x45, 0x3e, 0x41, + 0x49, 0x44, 0x4d, 0x49, 0x44, 0x41, 0x4a, 0x50, 0x44, 0x49, 0x4d, 0x47, + 0x4a, 0x49, 0x46, 0x49, 0x40, 0x5b, 0x4d, 0x51, 0x47, 0x57, 0x49, 0x4f, + 0x56, 0x46, 0x3a, 0x4a, 0x3e, 0x22, 0x36, 0x5c, 0x44, 0x56, 0x46, 0x48, + 0x3a, 0x2d, 0x4a, 0x48, 0x44, 0x17, 0x41, 0x42, 0x40, 0x3d, 0x4e, 0x45, + 0x40, 0x26, 0x43, 0x52, 0x41, 0x40, 0x44, 0x4a, 0x48, 0x42, 0x4f, 0x47, + 0x46, 0x4c, 0x4a, 0x3b, 0x42, 0x3e, 0x3e, 0x49, 0x4e, 0x44, 0x4e, 0x49, + 0x47, 0x41, 0x47, 0x44, 0x4c, 0x45, 0x4d, 0x49, 0x49, 0x48, 0x55, 0x3d, + 0x4a, 0x45, 0x50, 0x4f, 0x46, 0x4c, 0x46, 0x45, 0x3c, 0x51, 0x4b, 0x5a, + 0x46, 0x47, 0x54, 0x41, 0x44, 0x40, 0x4f, 0x53, 0x49, 0x46, 0x46, 0x48, + 0x44, 0x40, 0x50, 0x49, 0x49, 0x43, 0x50, 0x41, 0x52, 0x4b, 0x46, 0x3e, + 0x44, 0x44, 0x46, 0x4e, 0x47, 0x48, 0x3e, 0x38, 0x4c, 0x4c, 0x48, 0x43, + 0x48, 0x3e, 0x50, 0x42, 0x51, 0x50, 0x4a, 0x48, 0x4a, 0x42, 0x44, 0x3d, + 0x4a, 0x46, 0x46, 0x3d, 0x4e, 0x47, 0x3d, 0x48, 0x4c, 0x46, 0x50, 0x4d, + 0x49, 0x45, 0x4a, 0x4c, 0x4c, 0x47, 0x4a, 0x42, 0x4a, 0x45, 0x50, 0x52, + 0x4b, 0x4d, 0x4c, 0x43, 0x42, 0x53, 0x41, 0x45, 0x49, 0x41, 0x4b, 0x4c, + 0x52, 0x54, 0x4b, 0x41, 0x48, 0x4c, 0x47, 0x4c, 0x41, 0x49, 0x4a, 0x47, + 0x50, 0x59, 0x4e, 0x45, 0x3c, 0x5d, 0x53, 0x4c, 0x5a, 0x3e, 0x3a, 0x51, + 0x3a, 0x22, 0x35, 0x59, 0x40, 0x5a, 0x43, 0x46, 0x41, 0x32, 0x44, 0x4b, + 0x47, 0x04, 0x4c, 0x3a, 0x4a, 0x49, 0x48, 0x3d, 0x45, 0x2b, 0x50, 0x41, + 0x3e, 0x44, 0x4f, 0x43, 0x4a, 0x3f, 0x48, 0x4b, 0x53, 0x49, 0x4b, 0x38, + 0x44, 0x40, 0x48, 0x4c, 0x41, 0x3f, 0x47, 0x3e, 0x47, 0x49, 0x45, 0x42, + 0x43, 0x3e, 0x46, 0x44, 0x53, 0x4d, 0x48, 0x44, 0x45, 0x42, 0x43, 0x53, + 0x55, 0x49, 0x4d, 0x4b, 0x45, 0x44, 0x47, 0x5f, 0x48, 0x44, 0x4a, 0x48, + 0x45, 0x4d, 0x4f, 0x5e, 0x4e, 0x46, 0x49, 0x49, 0x4d, 0x49, 0x44, 0x48, + 0x4d, 0x41, 0x50, 0x48, 0x3d, 0x3f, 0x4d, 0x38, 0x46, 0x4a, 0x50, 0x4a, + 0x45, 0x3e, 0x43, 0x36, 0x42, 0x48, 0x53, 0x54, 0x49, 0x43, 0x4b, 0x3a, + 0x45, 0x48, 0x50, 0x45, 0x4a, 0x4c, 0x4a, 0x4d, 0x43, 0x4c, 0x55, 0x4e, + 0x4c, 0x42, 0x45, 0x52, 0x52, 0x45, 0x46, 0x40, 0x54, 0x4c, 0x3d, 0x4e, + 0x49, 0x4e, 0x44, 0x47, 0x45, 0x48, 0x4b, 0x50, 0x49, 0x4b, 0x44, 0x4b, + 0x4f, 0x49, 0x47, 0x47, 0x53, 0x3f, 0x4b, 0x42, 0x45, 0x3e, 0x4d, 0x4d, + 0x48, 0x51, 0x45, 0x40, 0x43, 0x43, 0x4e, 0x44, 0x51, 0x55, 0x4a, 0x3e, + 0x45, 0x55, 0x58, 0x50, 0x50, 0x38, 0x44, 0x4f, 0x3b, 0x23, 0x3c, 0x55, + 0x3c, 0x54, 0x49, 0x42, 0x44, 0x2f, 0x3e, 0x47, 0x42, 0x01, 0x42, 0x37, + 0x3f, 0x42, 0x45, 0x45, 0x47, 0x2a, 0x52, 0x4b, 0x45, 0x3c, 0x47, 0x44, + 0x44, 0x40, 0x50, 0x53, 0x48, 0x42, 0x4d, 0x36, 0x50, 0x3d, 0x49, 0x44, + 0x4f, 0x4c, 0x4a, 0x42, 0x4d, 0x3e, 0x3d, 0x3f, 0x4e, 0x44, 0x4d, 0x4e, + 0x54, 0x3d, 0x42, 0x46, 0x49, 0x47, 0x4b, 0x53, 0x45, 0x46, 0x47, 0x4a, + 0x45, 0x3d, 0x4a, 0x5f, 0x51, 0x3e, 0x45, 0x45, 0x44, 0x3a, 0x4d, 0x57, + 0x45, 0x47, 0x4d, 0x45, 0x4e, 0x4b, 0x51, 0x48, 0x4b, 0x4a, 0x3c, 0x4e, + 0x51, 0x41, 0x4d, 0x36, 0x47, 0x4a, 0x46, 0x51, 0x4e, 0x4c, 0x52, 0x41, + 0x55, 0x47, 0x41, 0x47, 0x4d, 0x47, 0x4b, 0x3d, 0x4a, 0x4a, 0x46, 0x49, + 0x4d, 0x48, 0x46, 0x46, 0x4d, 0x52, 0x52, 0x48, 0x49, 0x3f, 0x4b, 0x4e, + 0x4c, 0x49, 0x45, 0x47, 0x41, 0x4b, 0x44, 0x48, 0x52, 0x4b, 0x53, 0x44, + 0x46, 0x4e, 0x44, 0x49, 0x52, 0x50, 0x46, 0x4b, 0x44, 0x43, 0x50, 0x49, + 0x4a, 0x53, 0x45, 0x49, 0x52, 0x3f, 0x4a, 0x4e, 0x49, 0x4c, 0x4d, 0x4d, + 0x40, 0x40, 0x3f, 0x4a, 0x47, 0x56, 0x51, 0x43, 0x40, 0x5a, 0x58, 0x52, + 0x4f, 0x3d, 0x3d, 0x45, 0x38, 0x29, 0x33, 0x59, 0x45, 0x54, 0x3c, 0x42, + 0x3f, 0x27, 0x3e, 0x49, 0x48, 0x06, 0x4a, 0x3f, 0x41, 0x49, 0x4c, 0x48, + 0x46, 0x2b, 0x4a, 0x4f, 0x44, 0x46, 0x4c, 0x46, 0x4a, 0x3b, 0x4d, 0x4a, + 0x40, 0x41, 0x45, 0x38, 0x51, 0x39, 0x46, 0x46, 0x41, 0x51, 0x4e, 0x41, + 0x49, 0x44, 0x48, 0x4a, 0x4b, 0x46, 0x47, 0x46, 0x4a, 0x4c, 0x47, 0x48, + 0x3d, 0x42, 0x50, 0x4f, 0x50, 0x4a, 0x4a, 0x48, 0x4a, 0x45, 0x45, 0x61, + 0x4a, 0x4c, 0x49, 0x3d, 0x4b, 0x4a, 0x4a, 0x5a, 0x48, 0x49, 0x50, 0x4f, + 0x42, 0x48, 0x3e, 0x44, 0x43, 0x3b, 0x4f, 0x54, 0x4b, 0x4a, 0x47, 0x31, + 0x4a, 0x49, 0x47, 0x4e, 0x48, 0x48, 0x46, 0x42, 0x4a, 0x45, 0x4c, 0x49, + 0x4b, 0x4e, 0x53, 0x43, 0x4c, 0x49, 0x4f, 0x4b, 0x46, 0x4c, 0x4b, 0x4e, + 0x51, 0x4b, 0x49, 0x52, 0x44, 0x55, 0x45, 0x49, 0x4b, 0x4a, 0x50, 0x4c, + 0x4d, 0x4a, 0x4b, 0x48, 0x41, 0x46, 0x47, 0x43, 0x4b, 0x3f, 0x54, 0x4a, + 0x46, 0x49, 0x51, 0x48, 0x4e, 0x4a, 0x41, 0x52, 0x52, 0x4e, 0x53, 0x47, + 0x42, 0x48, 0x43, 0x44, 0x54, 0x51, 0x40, 0x49, 0x4c, 0x48, 0x49, 0x44, + 0x4c, 0x56, 0x52, 0x49, 0x3d, 0x59, 0x4f, 0x56, 0x56, 0x42, 0x46, 0x45, + 0x3e, 0x28, 0x3f, 0x5b, 0x3f, 0x5a, 0x4c, 0x42, 0x44, 0x22, 0x3f, 0x46, + 0x47, 0x0d, 0x3e, 0x41, 0x45, 0x49, 0x4a, 0x3b, 0x45, 0x2d, 0x4d, 0x4a, + 0x44, 0x43, 0x49, 0x46, 0x4b, 0x47, 0x49, 0x45, 0x4e, 0x40, 0x4c, 0x3c, + 0x42, 0x3e, 0x4b, 0x50, 0x48, 0x49, 0x4c, 0x42, 0x3c, 0x43, 0x50, 0x43, + 0x49, 0x4e, 0x4e, 0x43, 0x46, 0x4c, 0x48, 0x4a, 0x43, 0x4c, 0x49, 0x4e, + 0x47, 0x44, 0x50, 0x4c, 0x4a, 0x48, 0x47, 0x5f, 0x3f, 0x3e, 0x48, 0x4f, + 0x4f, 0x49, 0x4a, 0x5f, 0x4e, 0x40, 0x4e, 0x48, 0x47, 0x44, 0x40, 0x4d, + 0x3f, 0x4a, 0x53, 0x45, 0x3e, 0x50, 0x3f, 0x39, 0x50, 0x45, 0x45, 0x4b, + 0x43, 0x41, 0x46, 0x41, 0x49, 0x47, 0x4b, 0x41, 0x3c, 0x4b, 0x46, 0x3f, + 0x41, 0x4a, 0x4e, 0x4c, 0x49, 0x4c, 0x3f, 0x44, 0x53, 0x4c, 0x45, 0x49, + 0x48, 0x4d, 0x48, 0x4a, 0x48, 0x4f, 0x45, 0x4d, 0x48, 0x4c, 0x41, 0x49, + 0x42, 0x48, 0x53, 0x46, 0x4a, 0x46, 0x4b, 0x4f, 0x4c, 0x52, 0x4c, 0x51, + 0x41, 0x4d, 0x49, 0x41, 0x49, 0x4f, 0x49, 0x42, 0x4a, 0x48, 0x51, 0x4a, + 0x44, 0x4d, 0x55, 0x48, 0x47, 0x4d, 0x4d, 0x45, 0x42, 0x60, 0x4a, 0x51, + 0x42, 0x54, 0x56, 0x56, 0x50, 0x4a, 0x3f, 0x4a, 0x40, 0x25, 0x3a, 0x59, + 0x46, 0x58, 0x52, 0x46, 0x41, 0x28, 0x3d, 0x3e, 0x45, 0x13, 0x47, 0x41, + 0x3d, 0x44, 0x48, 0x45, 0x49, 0x26, 0x46, 0x4c, 0x3b, 0x4a, 0x42, 0x47, + 0x46, 0x41, 0x44, 0x52, 0x50, 0x4a, 0x4f, 0x40, 0x4b, 0x39, 0x42, 0x45, + 0x4a, 0x4d, 0x4f, 0x3f, 0x42, 0x4f, 0x49, 0x45, 0x42, 0x4a, 0x46, 0x47, + 0x48, 0x40, 0x4a, 0x46, 0x41, 0x3b, 0x48, 0x55, 0x4b, 0x4e, 0x4e, 0x48, + 0x4b, 0x44, 0x46, 0x53, 0x48, 0x45, 0x4b, 0x53, 0x49, 0x43, 0x4a, 0x5c, + 0x46, 0x45, 0x45, 0x49, 0x49, 0x49, 0x4c, 0x43, 0x4e, 0x4a, 0x41, 0x4a, + 0x42, 0x43, 0x4a, 0x38, 0x44, 0x4a, 0x4b, 0x3f, 0x45, 0x49, 0x45, 0x38, + 0x43, 0x40, 0x45, 0x4c, 0x47, 0x42, 0x3f, 0x42, 0x3e, 0x4a, 0x43, 0x50, + 0x4a, 0x4e, 0x4f, 0x47, 0x4d, 0x49, 0x49, 0x47, 0x4a, 0x4d, 0x46, 0x4c, + 0x4f, 0x3d, 0x52, 0x4a, 0x41, 0x44, 0x4b, 0x50, 0x4c, 0x52, 0x49, 0x50, + 0x4b, 0x45, 0x49, 0x4d, 0x48, 0x55, 0x50, 0x47, 0x4e, 0x50, 0x4f, 0x48, + 0x46, 0x4d, 0x4d, 0x41, 0x48, 0x51, 0x4b, 0x4c, 0x47, 0x51, 0x42, 0x42, + 0x4d, 0x47, 0x43, 0x4c, 0x4c, 0x5a, 0x4e, 0x47, 0x3b, 0x59, 0x51, 0x57, + 0x4c, 0x40, 0x46, 0x4c, 0x37, 0x2a, 0x35, 0x58, 0x44, 0x5b, 0x4c, 0x44, + 0x3e, 0x2e, 0x3f, 0x43, 0x46, 0x23, 0x49, 0x3e, 0x41, 0x3f, 0x4b, 0x3e, + 0x4e, 0x2f, 0x4d, 0x4a, 0x4e, 0x40, 0x4e, 0x41, 0x40, 0x3f, 0x4a, 0x42, + 0x4d, 0x4c, 0x44, 0x47, 0x4e, 0x44, 0x40, 0x43, 0x4d, 0x49, 0x4f, 0x3d, + 0x49, 0x3f, 0x51, 0x48, 0x42, 0x4a, 0x49, 0x47, 0x49, 0x46, 0x4a, 0x45, + 0x45, 0x49, 0x53, 0x4d, 0x4c, 0x4e, 0x44, 0x50, 0x4b, 0x43, 0x4e, 0x5f, + 0x3c, 0x40, 0x44, 0x46, 0x48, 0x4b, 0x42, 0x62, 0x4e, 0x50, 0x4c, 0x49, + 0x4a, 0x4f, 0x44, 0x53, 0x42, 0x43, 0x49, 0x48, 0x4b, 0x3c, 0x4a, 0x37, + 0x4c, 0x41, 0x49, 0x46, 0x46, 0x47, 0x43, 0x40, 0x4d, 0x4d, 0x4a, 0x48, + 0x50, 0x4b, 0x50, 0x41, 0x44, 0x3e, 0x51, 0x47, 0x44, 0x4a, 0x44, 0x45, + 0x48, 0x4d, 0x52, 0x4e, 0x44, 0x48, 0x4d, 0x43, 0x42, 0x45, 0x48, 0x52, + 0x44, 0x42, 0x50, 0x42, 0x4d, 0x45, 0x48, 0x4d, 0x4f, 0x4e, 0x45, 0x49, + 0x51, 0x48, 0x4f, 0x53, 0x4d, 0x4c, 0x48, 0x50, 0x4e, 0x4d, 0x50, 0x48, + 0x49, 0x42, 0x4c, 0x42, 0x4b, 0x4b, 0x49, 0x48, 0x48, 0x49, 0x4a, 0x54, + 0x44, 0x57, 0x4d, 0x4b, 0x3f, 0x56, 0x53, 0x5c, 0x50, 0x4e, 0x46, 0x49, + 0x40, 0x24, 0x44, 0x58, 0x49, 0x54, 0x48, 0x49, 0x41, 0x22, 0x44, 0x3f, + 0x48, 0x1c, 0x4d, 0x39, 0x3e, 0x4c, 0x3d, 0x4a, 0x48, 0x2d, 0x48, 0x3e, + 0x3f, 0x3a, 0x46, 0x4e, 0x44, 0x43, 0x49, 0x51, 0x4d, 0x3c, 0x44, 0x41, + 0x4e, 0x44, 0x42, 0x4c, 0x45, 0x48, 0x45, 0x46, 0x42, 0x46, 0x47, 0x42, + 0x4f, 0x45, 0x47, 0x44, 0x48, 0x47, 0x4a, 0x42, 0x4d, 0x48, 0x3e, 0x53, + 0x47, 0x4b, 0x44, 0x4b, 0x45, 0x4a, 0x50, 0x55, 0x4c, 0x45, 0x48, 0x43, + 0x53, 0x3d, 0x4e, 0x5f, 0x42, 0x44, 0x4a, 0x4f, 0x3f, 0x48, 0x4e, 0x4b, + 0x43, 0x48, 0x43, 0x41, 0x4a, 0x4b, 0x51, 0x39, 0x52, 0x46, 0x44, 0x49, + 0x48, 0x45, 0x4c, 0x40, 0x45, 0x49, 0x51, 0x48, 0x45, 0x42, 0x45, 0x48, + 0x40, 0x43, 0x3d, 0x47, 0x53, 0x54, 0x4d, 0x4a, 0x4a, 0x47, 0x48, 0x43, + 0x4c, 0x46, 0x43, 0x4f, 0x49, 0x4c, 0x3f, 0x3d, 0x4b, 0x41, 0x40, 0x48, + 0x4e, 0x4c, 0x4b, 0x40, 0x4c, 0x43, 0x49, 0x4d, 0x47, 0x4f, 0x47, 0x42, + 0x47, 0x4a, 0x4d, 0x4f, 0x46, 0x4d, 0x51, 0x49, 0x48, 0x4d, 0x4e, 0x46, + 0x47, 0x41, 0x44, 0x4d, 0x4b, 0x55, 0x4b, 0x4c, 0x41, 0x5e, 0x50, 0x45, + 0x40, 0x55, 0x4b, 0x60, 0x55, 0x47, 0x3d, 0x4a, 0x42, 0x22, 0x46, 0x5a, + 0x47, 0x53, 0x49, 0x44, 0x44, 0x27, 0x41, 0x4f, 0x3e, 0x22, 0x4a, 0x44, + 0x49, 0x3e, 0x4e, 0x4d, 0x3f, 0x3a, 0x4c, 0x44, 0x4a, 0x44, 0x46, 0x51, + 0x4f, 0x42, 0x4c, 0x4e, 0x39, 0x4b, 0x42, 0x39, 0x4b, 0x3e, 0x4f, 0x47, + 0x4a, 0x4f, 0x3f, 0x4d, 0x43, 0x4c, 0x4a, 0x4b, 0x4b, 0x3d, 0x51, 0x46, + 0x49, 0x4c, 0x47, 0x44, 0x43, 0x3d, 0x3c, 0x54, 0x4a, 0x47, 0x4d, 0x50, + 0x4a, 0x46, 0x51, 0x62, 0x46, 0x4d, 0x4b, 0x46, 0x49, 0x3c, 0x50, 0x57, + 0x47, 0x40, 0x3e, 0x4c, 0x4b, 0x3f, 0x55, 0x46, 0x3d, 0x45, 0x42, 0x4e, + 0x50, 0x49, 0x46, 0x3a, 0x4c, 0x47, 0x4a, 0x49, 0x42, 0x42, 0x4a, 0x44, + 0x42, 0x40, 0x49, 0x54, 0x46, 0x4b, 0x47, 0x45, 0x51, 0x47, 0x41, 0x42, + 0x49, 0x50, 0x4e, 0x48, 0x4b, 0x4b, 0x47, 0x4a, 0x47, 0x49, 0x4b, 0x45, + 0x4b, 0x54, 0x48, 0x54, 0x4b, 0x49, 0x51, 0x4a, 0x4a, 0x40, 0x46, 0x42, + 0x44, 0x44, 0x4d, 0x4b, 0x47, 0x43, 0x45, 0x41, 0x3e, 0x49, 0x43, 0x51, + 0x3e, 0x4b, 0x52, 0x46, 0x48, 0x3f, 0x4e, 0x51, 0x51, 0x49, 0x3f, 0x48, + 0x4c, 0x4c, 0x52, 0x47, 0x43, 0x57, 0x44, 0x42, 0x40, 0x52, 0x50, 0x5d, + 0x4f, 0x40, 0x42, 0x45, 0x46, 0x26, 0x3c, 0x51, 0x4b, 0x4e, 0x4b, 0x49, + 0x46, 0x35, 0x49, 0x53, 0x49, 0x2b, 0x4d, 0x3e, 0x50, 0x44, 0x4f, 0x54, + 0x46, 0x34, 0x49, 0x4d, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x44, 0x52, 0x41, + 0x4d, 0x4c, 0x52, 0x41, 0x49, 0x3a, 0x4e, 0x49, 0x40, 0x4b, 0x45, 0x4d, + 0x4b, 0x4a, 0x47, 0x49, 0x45, 0x49, 0x4d, 0x50, 0x3e, 0x47, 0x44, 0x51, + 0x4c, 0x41, 0x45, 0x50, 0x47, 0x41, 0x4a, 0x52, 0x4b, 0x3d, 0x4b, 0x5b, + 0x4c, 0x4c, 0x4d, 0x3f, 0x47, 0x44, 0x49, 0x5d, 0x4a, 0x53, 0x44, 0x45, + 0x45, 0x46, 0x3d, 0x4f, 0x50, 0x3b, 0x44, 0x4e, 0x40, 0x41, 0x4c, 0x3a, + 0x4a, 0x45, 0x49, 0x48, 0x45, 0x4a, 0x45, 0x36, 0x45, 0x4d, 0x4c, 0x49, + 0x3f, 0x47, 0x4d, 0x40, 0x53, 0x48, 0x49, 0x4c, 0x47, 0x4f, 0x42, 0x44, + 0x45, 0x40, 0x4a, 0x4c, 0x49, 0x4f, 0x4b, 0x4d, 0x42, 0x45, 0x3e, 0x4a, + 0x48, 0x4a, 0x49, 0x50, 0x4c, 0x53, 0x50, 0x45, 0x4b, 0x4c, 0x46, 0x4f, + 0x44, 0x43, 0x54, 0x50, 0x3f, 0x48, 0x42, 0x4b, 0x43, 0x3f, 0x4d, 0x4c, + 0x43, 0x49, 0x4a, 0x47, 0x54, 0x4b, 0x4f, 0x4d, 0x44, 0x47, 0x49, 0x4e, + 0x4e, 0x55, 0x40, 0x46, 0x44, 0x56, 0x4e, 0x65, 0x4f, 0x3f, 0x43, 0x48, + 0x39, 0x27, 0x43, 0x55, 0x4b, 0x4c, 0x44, 0x46, 0x42, 0x34, 0x44, 0x52, + 0x43, 0x22, 0x4e, 0x41, 0x49, 0x48, 0x49, 0x51, 0x3b, 0x37, 0x4b, 0x40, + 0x4f, 0x45, 0x53, 0x4c, 0x47, 0x46, 0x47, 0x4c, 0x3e, 0x44, 0x45, 0x49, + 0x48, 0x50, 0x45, 0x40, 0x46, 0x4c, 0x47, 0x4d, 0x44, 0x48, 0x49, 0x50, + 0x4f, 0x4a, 0x46, 0x55, 0x4e, 0x42, 0x4c, 0x4c, 0x50, 0x48, 0x3d, 0x55, + 0x46, 0x3e, 0x4a, 0x4b, 0x4f, 0x46, 0x46, 0x60, 0x50, 0x3f, 0x55, 0x40, + 0x42, 0x44, 0x48, 0x63, 0x50, 0x3d, 0x45, 0x4f, 0x4e, 0x41, 0x47, 0x48, + 0x4a, 0x3c, 0x3d, 0x46, 0x3f, 0x42, 0x43, 0x37, 0x4f, 0x4f, 0x50, 0x47, + 0x47, 0x4b, 0x52, 0x40, 0x3f, 0x44, 0x4a, 0x40, 0x4d, 0x44, 0x4e, 0x37, + 0x43, 0x48, 0x47, 0x3f, 0x51, 0x4d, 0x45, 0x42, 0x41, 0x46, 0x3d, 0x53, + 0x4f, 0x4b, 0x54, 0x45, 0x51, 0x40, 0x4a, 0x4a, 0x48, 0x4f, 0x43, 0x4a, + 0x4f, 0x4c, 0x4c, 0x4f, 0x48, 0x4c, 0x44, 0x4e, 0x43, 0x46, 0x4f, 0x4a, + 0x43, 0x41, 0x49, 0x49, 0x47, 0x53, 0x45, 0x49, 0x4e, 0x46, 0x4c, 0x4e, + 0x3c, 0x49, 0x44, 0x45, 0x4c, 0x42, 0x49, 0x41, 0x48, 0x58, 0x54, 0x4d, + 0x35, 0x52, 0x4e, 0x5b, 0x4f, 0x40, 0x3e, 0x46, 0x46, 0x36, 0x3d, 0x60, + 0x4d, 0x49, 0x4a, 0x43, 0x44, 0x36, 0x49, 0x67, 0x4a, 0x2d, 0x4b, 0x40, + 0x3f, 0x49, 0x43, 0x5f, 0x45, 0x3c, 0x49, 0x4c, 0x4a, 0x43, 0x48, 0x55, + 0x49, 0x46, 0x49, 0x46, 0x44, 0x4e, 0x42, 0x4e, 0x40, 0x45, 0x42, 0x52, + 0x4a, 0x40, 0x4a, 0x44, 0x40, 0x45, 0x54, 0x3d, 0x4c, 0x3e, 0x4c, 0x55, + 0x4d, 0x45, 0x4d, 0x51, 0x4a, 0x4b, 0x44, 0x5b, 0x48, 0x3d, 0x3e, 0x46, + 0x4f, 0x4d, 0x3f, 0x62, 0x4d, 0x45, 0x3f, 0x47, 0x47, 0x47, 0x44, 0x5b, + 0x4b, 0x4f, 0x51, 0x4c, 0x4a, 0x47, 0x48, 0x5b, 0x47, 0x40, 0x4a, 0x47, + 0x42, 0x44, 0x46, 0x46, 0x45, 0x48, 0x4a, 0x3f, 0x40, 0x4f, 0x48, 0x3a, + 0x49, 0x52, 0x4a, 0x53, 0x43, 0x4c, 0x4b, 0x4a, 0x4a, 0x4a, 0x4e, 0x42, + 0x4b, 0x46, 0x3d, 0x50, 0x51, 0x4b, 0x4b, 0x4f, 0x50, 0x4c, 0x4f, 0x4c, + 0x4d, 0x41, 0x41, 0x3c, 0x40, 0x43, 0x54, 0x51, 0x48, 0x3d, 0x48, 0x51, + 0x42, 0x42, 0x4c, 0x4e, 0x4d, 0x4b, 0x49, 0x43, 0x48, 0x47, 0x4b, 0x49, + 0x49, 0x4e, 0x4d, 0x46, 0x4c, 0x52, 0x49, 0x49, 0x51, 0x4e, 0x45, 0x47, + 0x44, 0x47, 0x42, 0x4a, 0x46, 0x59, 0x48, 0x48, 0x4b, 0x4f, 0x4c, 0x5e, + 0x5c, 0x45, 0x3f, 0x48, 0x3d, 0x3f, 0x37, 0x5a, 0x4b, 0x4b, 0x45, 0x49, + 0x3e, 0x42, 0x41, 0x6b, 0x49, 0x2d, 0x45, 0x43, 0x47, 0x45, 0x49, 0x61, + 0x3d, 0x3b, 0x49, 0x43, 0x49, 0x4b, 0x4b, 0x55, 0x4b, 0x47, 0x46, 0x46, + 0x48, 0x4d, 0x49, 0x4f, 0x4a, 0x4c, 0x42, 0x51, 0x41, 0x44, 0x45, 0x4f, + 0x4e, 0x44, 0x3f, 0x55, 0x3e, 0x4a, 0x45, 0x50, 0x46, 0x42, 0x41, 0x49, + 0x49, 0x47, 0x49, 0x61, 0x47, 0x40, 0x41, 0x4e, 0x4d, 0x4b, 0x4a, 0x5e, + 0x52, 0x49, 0x4b, 0x52, 0x51, 0x55, 0x42, 0x61, 0x53, 0x4c, 0x48, 0x4a, + 0x4e, 0x48, 0x48, 0x57, 0x4c, 0x40, 0x40, 0x48, 0x45, 0x43, 0x3e, 0x46, + 0x43, 0x4a, 0x45, 0x45, 0x44, 0x4f, 0x44, 0x40, 0x49, 0x48, 0x4e, 0x49, + 0x4a, 0x4e, 0x49, 0x51, 0x46, 0x4f, 0x47, 0x44, 0x42, 0x4d, 0x43, 0x4e, + 0x4f, 0x4d, 0x44, 0x51, 0x47, 0x49, 0x40, 0x57, 0x4b, 0x49, 0x47, 0x4c, + 0x4d, 0x4d, 0x3e, 0x47, 0x45, 0x41, 0x50, 0x4b, 0x4b, 0x45, 0x42, 0x4e, + 0x48, 0x47, 0x4e, 0x4b, 0x56, 0x4c, 0x4f, 0x52, 0x51, 0x49, 0x4d, 0x4a, + 0x4b, 0x52, 0x4d, 0x55, 0x4b, 0x4e, 0x4e, 0x4b, 0x51, 0x57, 0x47, 0x42, + 0x49, 0x48, 0x56, 0x44, 0x52, 0x56, 0x53, 0x5a, 0x63, 0x53, 0x4c, 0x4c, + 0x43, 0x56, 0x3c, 0x57, 0x47, 0x47, 0x4d, 0x52, 0x43, 0x48, 0x45, 0x5f, + 0x45, 0x29, 0x47, 0x45, 0x48, 0x40, 0x41, 0x4b, 0x3f, 0x39, 0x49, 0x4e, + 0x47, 0x55, 0x42, 0x56, 0x4d, 0x43, 0x48, 0x44, 0x45, 0x53, 0x43, 0x46, + 0x49, 0x43, 0x49, 0x4a, 0x40, 0x4e, 0x4a, 0x4a, 0x47, 0x43, 0x45, 0x4d, + 0x4a, 0x47, 0x3f, 0x53, 0x45, 0x43, 0x4b, 0x4c, 0x42, 0x47, 0x47, 0x5f, + 0x48, 0x48, 0x46, 0x44, 0x50, 0x47, 0x41, 0x64, 0x4e, 0x46, 0x49, 0x4a, + 0x4d, 0x55, 0x42, 0x55, 0x46, 0x3d, 0x49, 0x43, 0x52, 0x52, 0x47, 0x52, + 0x4e, 0x46, 0x47, 0x41, 0x49, 0x4d, 0x50, 0x47, 0x42, 0x49, 0x41, 0x42, + 0x4b, 0x48, 0x49, 0x42, 0x4d, 0x48, 0x51, 0x54, 0x43, 0x56, 0x4c, 0x52, + 0x53, 0x4d, 0x54, 0x4a, 0x51, 0x50, 0x48, 0x4c, 0x4e, 0x48, 0x4c, 0x4c, + 0x52, 0x49, 0x4a, 0x4e, 0x4e, 0x41, 0x4f, 0x53, 0x49, 0x52, 0x42, 0x4b, + 0x50, 0x46, 0x50, 0x4a, 0x53, 0x56, 0x46, 0x4f, 0x4b, 0x49, 0x3d, 0x41, + 0x4c, 0x52, 0x42, 0x50, 0x4d, 0x45, 0x4e, 0x51, 0x4b, 0x4c, 0x46, 0x42, + 0x41, 0x4b, 0x40, 0x4a, 0x42, 0x57, 0x4f, 0x43, 0x40, 0x50, 0x4c, 0x51, + 0x4f, 0x48, 0x3a, 0x4e, 0x51, 0x40, 0x49, 0x66, 0x4b, 0x42, 0x48, 0x3c, + 0x5b, 0x47, 0x53, 0x40, 0x4a, 0x48, 0x35, 0x44, 0x5f, 0x50, 0x4a, 0x3c, + 0x41, 0x45, 0x48, 0x3b, 0x42, 0x59, 0x43, 0x4b, 0x48, 0x49, 0x4a, 0x40, + 0x4f, 0x5c, 0x50, 0x54, 0x53, 0x55, 0x4c, 0x4a, 0x43, 0x46, 0x49, 0x47, + 0x49, 0x48, 0x4b, 0x43, 0x42, 0x44, 0x42, 0x46, 0x44, 0x3f, 0x4b, 0x42, + 0x4d, 0x49, 0x41, 0x46, 0x47, 0x51, 0x51, 0x44, 0x4c, 0x54, 0x4e, 0x4b, + 0x42, 0x52, 0x4e, 0x4c, 0x4b, 0x4a, 0x50, 0x4e, 0x44, 0x4b, 0x4e, 0x4e, + 0x4f, 0x42, 0x4b, 0x48, 0x46, 0x43, 0x48, 0x54, 0x4b, 0x4e, 0x48, 0x4f, + 0x4a, 0x4d, 0x43, 0x4e, 0x47, 0x50, 0x4a, 0x44, 0x47, 0x52, 0x46, 0x53, + 0x4a, 0x40, 0x46, 0x54, 0x50, 0x4a, 0x47, 0x51, 0x49, 0x45, 0x4b, 0x4e, + 0x4b, 0x46, 0x4c, 0x4c, 0x52, 0x47, 0x45, 0x45, 0x4a, 0x47, 0x4c, 0x52, + 0x44, 0x51, 0x47, 0x42, 0x47, 0x43, 0x43, 0x49, 0x52, 0x5a, 0x55, 0x3e, + 0x45, 0x4b, 0x4c, 0x46, 0x4f, 0x4b, 0x45, 0x49, 0x4a, 0x4e, 0x4a, 0x50, + 0x3e, 0x4e, 0x42, 0x4e, 0x44, 0x55, 0x3d, 0x4a, 0x4d, 0x49, 0x4d, 0x42, + 0x49, 0x4e, 0x50, 0x44, 0x4b, 0x3c, 0x41, 0x49, 0x51, 0x49, 0x3c, 0x4e, + 0x4c, 0x39, 0x4c, 0x72, 0x44, 0x4b, 0x49, 0x42, 0x5f, 0x48, 0x4a, 0x48, + 0x41, 0x4c, 0x43, 0x40, 0x62, 0x5e, 0x47, 0x3c, 0x4a, 0x4c, 0x55, 0x49, + 0x4b, 0x52, 0x4e, 0x4b, 0x4d, 0x48, 0x4c, 0x3c, 0x3f, 0x4f, 0x4e, 0x48, + 0x45, 0x55, 0x4a, 0x46, 0x48, 0x3d, 0x45, 0x44, 0x4b, 0x4a, 0x46, 0x3a, + 0x4e, 0x44, 0x4d, 0x49, 0x49, 0x49, 0x40, 0x3e, 0x40, 0x47, 0x48, 0x43, + 0x3f, 0x51, 0x46, 0x4c, 0x45, 0x4c, 0x49, 0x44, 0x3e, 0x57, 0x49, 0x4e, + 0x48, 0x3f, 0x48, 0x47, 0x53, 0x4d, 0x50, 0x51, 0x49, 0x42, 0x45, 0x44, + 0x49, 0x49, 0x46, 0x4b, 0x45, 0x49, 0x4f, 0x49, 0x46, 0x48, 0x4c, 0x55, + 0x46, 0x51, 0x48, 0x4a, 0x48, 0x54, 0x4b, 0x5a, 0x4c, 0x47, 0x40, 0x47, + 0x40, 0x55, 0x50, 0x52, 0x4a, 0x4b, 0x4f, 0x49, 0x4b, 0x50, 0x4b, 0x5b, + 0x51, 0x53, 0x4f, 0x4e, 0x49, 0x48, 0x44, 0x52, 0x46, 0x4e, 0x47, 0x48, + 0x44, 0x43, 0x49, 0x55, 0x48, 0x58, 0x4f, 0x46, 0x45, 0x53, 0x45, 0x4a, + 0x4c, 0x4c, 0x49, 0x46, 0x47, 0x4d, 0x41, 0x4d, 0x4f, 0x59, 0x4a, 0x49, + 0x46, 0x4e, 0x44, 0x49, 0x4d, 0x48, 0x54, 0x47, 0x48, 0x4e, 0x48, 0x43, + 0x46, 0x41, 0x46, 0x44, 0x52, 0x46, 0x42, 0x4c, 0x4c, 0x31, 0x4d, 0x6f, + 0x51, 0x4f, 0x4d, 0x43, 0x5c, 0x48, 0x49, 0x49, 0x46, 0x4c, 0x43, 0x3b, + 0x5d, 0x63, 0x58, 0x46, 0x49, 0x45, 0x4e, 0x48, 0x49, 0x5d, 0x45, 0x50, + 0x56, 0x4d, 0x57, 0x37, 0x40, 0x55, 0x43, 0x4b, 0x4e, 0x46, 0x4c, 0x3b, + 0x3d, 0x4b, 0x49, 0x4b, 0x52, 0x47, 0x4d, 0x34, 0x4c, 0x4c, 0x47, 0x4e, + 0x4d, 0x4c, 0x3d, 0x3f, 0x4a, 0x49, 0x44, 0x45, 0x4a, 0x54, 0x43, 0x44, + 0x50, 0x4b, 0x4d, 0x4c, 0x4e, 0x48, 0x46, 0x51, 0x43, 0x48, 0x48, 0x48, + 0x42, 0x44, 0x4e, 0x48, 0x47, 0x45, 0x48, 0x51, 0x53, 0x4a, 0x4f, 0x58, + 0x42, 0x4d, 0x48, 0x4f, 0x4c, 0x45, 0x4a, 0x57, 0x4b, 0x43, 0x4d, 0x4b, + 0x4a, 0x4e, 0x4c, 0x5f, 0x3f, 0x4f, 0x4a, 0x42, 0x4b, 0x48, 0x4d, 0x62, + 0x4f, 0x4b, 0x50, 0x4c, 0x45, 0x49, 0x44, 0x53, 0x4a, 0x4f, 0x45, 0x56, + 0x4b, 0x44, 0x41, 0x53, 0x49, 0x48, 0x4d, 0x49, 0x47, 0x4b, 0x46, 0x4c, + 0x49, 0x4b, 0x4c, 0x54, 0x4f, 0x4b, 0x47, 0x49, 0x44, 0x4a, 0x4e, 0x53, + 0x4f, 0x49, 0x54, 0x4e, 0x4a, 0x48, 0x42, 0x54, 0x51, 0x46, 0x4b, 0x52, + 0x45, 0x48, 0x51, 0x4a, 0x40, 0x4a, 0x50, 0x45, 0x4a, 0x46, 0x49, 0x46, + 0x54, 0x46, 0x42, 0x48, 0x50, 0x36, 0x4a, 0x6b, 0x46, 0x59, 0x51, 0x47, + 0x5f, 0x4d, 0x43, 0x4d, 0x44, 0x4d, 0x42, 0x3b, 0x65, 0x6a, 0x56, 0x48, + 0x4d, 0x4c, 0x52, 0x4a, 0x4d, 0x61, 0x52, 0x4b, 0x47, 0x4f, 0x48, 0x49, + 0x3f, 0x5b, 0x45, 0x51, 0x48, 0x48, 0x4b, 0x3c, 0x3b, 0x4c, 0x54, 0x52, + 0x4f, 0x51, 0x53, 0x31, 0x47, 0x4c, 0x45, 0x4a, 0x42, 0x4b, 0x47, 0x40, + 0x41, 0x49, 0x4c, 0x46, 0x4b, 0x53, 0x46, 0x49, 0x44, 0x4b, 0x4e, 0x4b, + 0x48, 0x51, 0x49, 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x45, 0x43, 0x46, 0x56, + 0x42, 0x4b, 0x49, 0x4e, 0x4e, 0x53, 0x42, 0x5c, 0x4b, 0x46, 0x49, 0x46, + 0x4e, 0x41, 0x42, 0x67, 0x41, 0x49, 0x4d, 0x48, 0x49, 0x4e, 0x3f, 0x61, + 0x48, 0x4a, 0x40, 0x42, 0x4c, 0x51, 0x50, 0x63, 0x49, 0x44, 0x49, 0x47, + 0x45, 0x4d, 0x49, 0x61, 0x3f, 0x48, 0x40, 0x41, 0x49, 0x49, 0x45, 0x57, + 0x45, 0x46, 0x4d, 0x46, 0x4c, 0x4a, 0x4d, 0x4b, 0x43, 0x54, 0x4b, 0x49, + 0x4c, 0x49, 0x41, 0x49, 0x4b, 0x47, 0x45, 0x4b, 0x44, 0x43, 0x46, 0x3f, + 0x47, 0x47, 0x43, 0x4c, 0x49, 0x4c, 0x3d, 0x4d, 0x4b, 0x54, 0x4a, 0x4f, + 0x44, 0x4c, 0x4b, 0x47, 0x4c, 0x45, 0x3d, 0x52, 0x58, 0x4b, 0x45, 0x4e, + 0x48, 0x39, 0x53, 0x70, 0x4a, 0x5d, 0x4c, 0x4e, 0x5a, 0x4f, 0x46, 0x4b, + 0x3e, 0x4f, 0x44, 0x3d, 0x66, 0x6b, 0x50, 0x4d, 0x4d, 0x57, 0x52, 0x4a, + 0x4c, 0x5b, 0x4e, 0x53, 0x4d, 0x54, 0x50, 0x42, 0x3c, 0x5d, 0x4a, 0x4c, + 0x56, 0x52, 0x50, 0x40, 0x48, 0x4c, 0x4d, 0x49, 0x49, 0x4f, 0x51, 0x38, + 0x42, 0x49, 0x4d, 0x4f, 0x45, 0x40, 0x4d, 0x41, 0x4b, 0x4a, 0x47, 0x51, + 0x4b, 0x53, 0x4c, 0x4a, 0x51, 0x4c, 0x42, 0x56, 0x48, 0x4a, 0x47, 0x58, + 0x49, 0x46, 0x52, 0x4a, 0x45, 0x47, 0x51, 0x54, 0x4f, 0x50, 0x50, 0x53, + 0x49, 0x4a, 0x4d, 0x56, 0x56, 0x4b, 0x4d, 0x45, 0x40, 0x4d, 0x48, 0x60, + 0x4e, 0x56, 0x48, 0x4b, 0x47, 0x45, 0x47, 0x62, 0x4e, 0x4f, 0x41, 0x49, + 0x48, 0x57, 0x44, 0x64, 0x4f, 0x4f, 0x49, 0x44, 0x49, 0x4c, 0x3f, 0x53, + 0x40, 0x41, 0x4e, 0x4b, 0x4d, 0x54, 0x42, 0x53, 0x4e, 0x41, 0x49, 0x44, + 0x41, 0x45, 0x4d, 0x4f, 0x47, 0x51, 0x45, 0x4a, 0x42, 0x45, 0x4e, 0x40, + 0x4b, 0x52, 0x48, 0x47, 0x4e, 0x4f, 0x47, 0x41, 0x48, 0x53, 0x47, 0x47, + 0x46, 0x42, 0x48, 0x4b, 0x42, 0x4c, 0x49, 0x4c, 0x45, 0x4c, 0x54, 0x45, + 0x4c, 0x43, 0x4e, 0x49, 0x56, 0x47, 0x45, 0x4f, 0x4d, 0x3a, 0x58, 0x74, + 0x49, 0x5b, 0x4c, 0x4f, 0x64, 0x4e, 0x45, 0x43, 0x44, 0x5b, 0x43, 0x41, + 0x63, 0x70, 0x55, 0x45, 0x4a, 0x4a, 0x4d, 0x51, 0x4b, 0x5a, 0x51, 0x57, + 0x54, 0x5b, 0x55, 0x44, 0x38, 0x57, 0x4e, 0x50, 0x4e, 0x56, 0x57, 0x3a, + 0x3a, 0x4b, 0x57, 0x4c, 0x51, 0x53, 0x4d, 0x3b, 0x44, 0x43, 0x47, 0x4c, + 0x48, 0x59, 0x51, 0x41, 0x43, 0x44, 0x51, 0x51, 0x4a, 0x54, 0x51, 0x4b, + 0x4e, 0x45, 0x51, 0x4a, 0x49, 0x4a, 0x4f, 0x52, 0x4c, 0x3e, 0x4e, 0x55, + 0x42, 0x46, 0x46, 0x4a, 0x42, 0x52, 0x49, 0x47, 0x4a, 0x56, 0x4f, 0x50, + 0x46, 0x4f, 0x43, 0x51, 0x53, 0x46, 0x40, 0x60, 0x44, 0x4d, 0x46, 0x54, + 0x3d, 0x49, 0x43, 0x64, 0x45, 0x4d, 0x50, 0x49, 0x4f, 0x4d, 0x53, 0x60, + 0x4a, 0x52, 0x49, 0x47, 0x48, 0x5a, 0x48, 0x58, 0x4e, 0x4f, 0x43, 0x4f, + 0x50, 0x51, 0x41, 0x52, 0x4c, 0x4d, 0x45, 0x42, 0x41, 0x4c, 0x44, 0x54, + 0x4e, 0x4d, 0x4a, 0x47, 0x40, 0x4a, 0x3e, 0x47, 0x4c, 0x58, 0x46, 0x46, + 0x55, 0x4c, 0x4d, 0x45, 0x49, 0x51, 0x53, 0x46, 0x46, 0x43, 0x43, 0x48, + 0x52, 0x3d, 0x4b, 0x4e, 0x49, 0x47, 0x3f, 0x3d, 0x4f, 0x45, 0x44, 0x3f, + 0x5a, 0x43, 0x4b, 0x4d, 0x51, 0x35, 0x54, 0x76, 0x4f, 0x5e, 0x4c, 0x50, + 0x5a, 0x51, 0x46, 0x49, 0x44, 0x61, 0x4f, 0x41, 0x67, 0x72, 0x56, 0x4f, + 0x42, 0x48, 0x4b, 0x52, 0x46, 0x60, 0x50, 0x4e, 0x4a, 0x5b, 0x5f, 0x46, + 0x31, 0x5b, 0x4a, 0x48, 0x4b, 0x58, 0x51, 0x41, 0x37, 0x4e, 0x4f, 0x55, + 0x51, 0x5c, 0x4f, 0x42, 0x4b, 0x4e, 0x4f, 0x54, 0x4f, 0x52, 0x43, 0x43, + 0x48, 0x53, 0x53, 0x41, 0x4b, 0x49, 0x4e, 0x50, 0x46, 0x4c, 0x4f, 0x49, + 0x42, 0x49, 0x4c, 0x4c, 0x4c, 0x41, 0x4e, 0x48, 0x47, 0x4c, 0x49, 0x53, + 0x44, 0x46, 0x51, 0x53, 0x45, 0x52, 0x4e, 0x53, 0x50, 0x58, 0x42, 0x45, + 0x44, 0x42, 0x48, 0x58, 0x4e, 0x4d, 0x54, 0x56, 0x4c, 0x46, 0x4a, 0x58, + 0x48, 0x4f, 0x47, 0x51, 0x47, 0x4f, 0x4f, 0x5b, 0x41, 0x4e, 0x45, 0x45, + 0x4a, 0x50, 0x3e, 0x57, 0x48, 0x4e, 0x41, 0x4c, 0x45, 0x51, 0x46, 0x4c, + 0x46, 0x4f, 0x42, 0x45, 0x4b, 0x4c, 0x49, 0x4c, 0x44, 0x4f, 0x4e, 0x4d, + 0x48, 0x56, 0x43, 0x48, 0x42, 0x54, 0x48, 0x43, 0x3e, 0x51, 0x43, 0x47, + 0x47, 0x47, 0x49, 0x4d, 0x46, 0x4e, 0x52, 0x42, 0x48, 0x4e, 0x4c, 0x4a, + 0x4d, 0x3e, 0x43, 0x40, 0x48, 0x41, 0x47, 0x4f, 0x5e, 0x49, 0x40, 0x4c, + 0x50, 0x42, 0x56, 0x75, 0x51, 0x5e, 0x51, 0x4e, 0x62, 0x58, 0x49, 0x47, + 0x51, 0x59, 0x46, 0x46, 0x6c, 0x72, 0x55, 0x44, 0x4c, 0x4a, 0x4d, 0x59, + 0x53, 0x64, 0x4d, 0x51, 0x55, 0x5e, 0x59, 0x50, 0x30, 0x58, 0x50, 0x4c, + 0x4c, 0x60, 0x59, 0x42, 0x32, 0x53, 0x50, 0x55, 0x4d, 0x53, 0x59, 0x43, + 0x3e, 0x49, 0x4f, 0x52, 0x4d, 0x51, 0x47, 0x45, 0x4d, 0x4e, 0x53, 0x4e, + 0x54, 0x4f, 0x4d, 0x4d, 0x4e, 0x40, 0x47, 0x53, 0x53, 0x49, 0x56, 0x4d, + 0x4d, 0x3a, 0x4c, 0x4e, 0x45, 0x4a, 0x47, 0x45, 0x53, 0x4a, 0x4e, 0x52, + 0x4d, 0x4e, 0x48, 0x56, 0x4e, 0x4a, 0x4d, 0x52, 0x49, 0x4e, 0x4e, 0x58, + 0x47, 0x50, 0x4c, 0x54, 0x49, 0x42, 0x46, 0x54, 0x50, 0x54, 0x54, 0x46, + 0x40, 0x49, 0x4b, 0x57, 0x4b, 0x59, 0x44, 0x46, 0x52, 0x55, 0x51, 0x55, + 0x4f, 0x50, 0x4d, 0x4d, 0x48, 0x50, 0x4e, 0x49, 0x4e, 0x42, 0x45, 0x3f, + 0x4d, 0x4f, 0x51, 0x47, 0x4a, 0x4c, 0x4b, 0x4b, 0x46, 0x4d, 0x44, 0x52, + 0x4d, 0x44, 0x40, 0x4d, 0x54, 0x46, 0x54, 0x44, 0x4b, 0x46, 0x47, 0x45, + 0x50, 0x45, 0x45, 0x4b, 0x4c, 0x48, 0x3f, 0x55, 0x4a, 0x45, 0x49, 0x4e, + 0x40, 0x49, 0x4a, 0x41, 0x56, 0x4b, 0x49, 0x4e, 0x4a, 0x41, 0x50, 0x70, + 0x56, 0x59, 0x4b, 0x55, 0x58, 0x59, 0x49, 0x47, 0x4a, 0x5a, 0x4c, 0x46, + 0x62, 0x7b, 0x58, 0x51, 0x44, 0x47, 0x44, 0x57, 0x4f, 0x65, 0x4e, 0x50, + 0x4d, 0x67, 0x5c, 0x4a, 0x2b, 0x61, 0x48, 0x4b, 0x4b, 0x5d, 0x5c, 0x48, + 0x39, 0x50, 0x45, 0x4d, 0x53, 0x60, 0x53, 0x46, 0x42, 0x46, 0x50, 0x45, + 0x4f, 0x4e, 0x46, 0x4a, 0x4d, 0x51, 0x54, 0x47, 0x59, 0x4b, 0x58, 0x4a, + 0x50, 0x3d, 0x59, 0x48, 0x45, 0x4e, 0x4e, 0x47, 0x4f, 0x47, 0x4d, 0x4b, + 0x52, 0x42, 0x4c, 0x48, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x4c, 0x4d, 0x51, + 0x49, 0x4f, 0x4c, 0x47, 0x47, 0x48, 0x47, 0x59, 0x4f, 0x4f, 0x53, 0x49, + 0x4e, 0x4b, 0x4f, 0x5a, 0x50, 0x42, 0x47, 0x50, 0x4a, 0x54, 0x47, 0x5a, + 0x43, 0x49, 0x47, 0x4e, 0x49, 0x4d, 0x43, 0x54, 0x4c, 0x53, 0x4e, 0x4e, + 0x42, 0x43, 0x48, 0x46, 0x4f, 0x43, 0x43, 0x45, 0x51, 0x47, 0x4b, 0x4f, + 0x56, 0x48, 0x48, 0x49, 0x46, 0x45, 0x4d, 0x52, 0x47, 0x4b, 0x46, 0x50, + 0x3e, 0x4e, 0x4c, 0x43, 0x45, 0x4d, 0x53, 0x43, 0x46, 0x45, 0x44, 0x52, + 0x45, 0x49, 0x49, 0x51, 0x3d, 0x4a, 0x4d, 0x46, 0x42, 0x41, 0x4e, 0x48, + 0x5a, 0x49, 0x49, 0x49, 0x4f, 0x3d, 0x56, 0x68, 0x56, 0x67, 0x4b, 0x57, + 0x5f, 0x5c, 0x40, 0x4a, 0x4a, 0x54, 0x4c, 0x47, 0x64, 0x7a, 0x54, 0x48, + 0x46, 0x45, 0x46, 0x57, 0x4e, 0x61, 0x4f, 0x50, 0x4d, 0x64, 0x5b, 0x43, + 0x2d, 0x60, 0x55, 0x51, 0x4c, 0x54, 0x4f, 0x4e, 0x2f, 0x50, 0x4f, 0x52, + 0x50, 0x61, 0x54, 0x4b, 0x3d, 0x4c, 0x47, 0x51, 0x4a, 0x54, 0x4b, 0x42, + 0x3b, 0x55, 0x47, 0x50, 0x4f, 0x49, 0x4a, 0x46, 0x43, 0x44, 0x45, 0x47, + 0x46, 0x4b, 0x4f, 0x46, 0x43, 0x47, 0x4a, 0x4e, 0x51, 0x43, 0x55, 0x47, + 0x4d, 0x46, 0x4c, 0x4c, 0x49, 0x4d, 0x43, 0x51, 0x47, 0x51, 0x52, 0x4a, + 0x46, 0x4f, 0x49, 0x52, 0x50, 0x4a, 0x43, 0x53, 0x46, 0x4e, 0x50, 0x54, + 0x45, 0x3a, 0x4a, 0x4a, 0x4c, 0x50, 0x4b, 0x54, 0x43, 0x4f, 0x4e, 0x45, + 0x49, 0x4f, 0x46, 0x53, 0x4d, 0x51, 0x52, 0x53, 0x3d, 0x4a, 0x47, 0x4e, + 0x43, 0x4a, 0x53, 0x48, 0x4a, 0x4c, 0x4a, 0x4a, 0x42, 0x53, 0x3e, 0x43, + 0x4f, 0x4c, 0x47, 0x48, 0x54, 0x4d, 0x48, 0x48, 0x4e, 0x4c, 0x43, 0x51, + 0x42, 0x49, 0x44, 0x3e, 0x49, 0x51, 0x4a, 0x4d, 0x4f, 0x49, 0x45, 0x44, + 0x4e, 0x41, 0x48, 0x4b, 0x4c, 0x49, 0x46, 0x47, 0x5d, 0x4c, 0x4d, 0x50, + 0x45, 0x40, 0x4e, 0x6a, 0x4f, 0x62, 0x53, 0x50, 0x5c, 0x5e, 0x4a, 0x4c, + 0x50, 0x56, 0x52, 0x42, 0x60, 0x7e, 0x5b, 0x4b, 0x43, 0x41, 0x4c, 0x56, + 0x46, 0x5f, 0x4d, 0x49, 0x43, 0x65, 0x5c, 0x4d, 0x2c, 0x61, 0x48, 0x4c, + 0x44, 0x55, 0x5c, 0x49, 0x37, 0x54, 0x4e, 0x57, 0x52, 0x5c, 0x50, 0x49, + 0x3e, 0x4d, 0x4f, 0x4f, 0x51, 0x4c, 0x48, 0x43, 0x4a, 0x5a, 0x4d, 0x4b, + 0x4e, 0x58, 0x54, 0x49, 0x51, 0x42, 0x49, 0x4f, 0x46, 0x45, 0x52, 0x3d, + 0x4b, 0x4b, 0x43, 0x54, 0x47, 0x47, 0x4c, 0x42, 0x4b, 0x49, 0x45, 0x46, + 0x46, 0x4a, 0x51, 0x47, 0x47, 0x4f, 0x48, 0x4a, 0x3f, 0x4c, 0x4b, 0x57, + 0x4a, 0x3f, 0x52, 0x4a, 0x56, 0x52, 0x4b, 0x54, 0x4c, 0x3e, 0x3f, 0x4f, + 0x4b, 0x50, 0x4c, 0x53, 0x4a, 0x49, 0x46, 0x4e, 0x50, 0x48, 0x4f, 0x4b, + 0x4a, 0x4e, 0x3e, 0x49, 0x45, 0x42, 0x42, 0x41, 0x47, 0x4b, 0x4f, 0x42, + 0x49, 0x4c, 0x55, 0x4c, 0x4e, 0x42, 0x47, 0x42, 0x4b, 0x48, 0x46, 0x41, + 0x46, 0x4e, 0x4d, 0x3f, 0x4f, 0x46, 0x4f, 0x4b, 0x4b, 0x4d, 0x50, 0x3e, + 0x42, 0x43, 0x44, 0x4a, 0x49, 0x40, 0x4e, 0x43, 0x3e, 0x52, 0x3e, 0x44, + 0x49, 0x43, 0x4d, 0x44, 0x62, 0x51, 0x42, 0x53, 0x51, 0x40, 0x4c, 0x64, + 0x4f, 0x63, 0x4e, 0x5c, 0x5b, 0x5c, 0x48, 0x4d, 0x4a, 0x57, 0x4f, 0x42, + 0x65, 0xfe, 0x5c, 0x4e, 0x47, 0x43, 0x4a, 0x58, 0x4e, 0x5e, 0x48, 0x4c, + 0x51, 0x5e, 0x60, 0x56, 0x2f, 0x62, 0x54, 0x58, 0x51, 0x52, 0x55, 0x51, + 0x36, 0x4b, 0x46, 0x51, 0x53, 0x5f, 0x46, 0x4c, 0x37, 0x4d, 0x4a, 0x45, + 0x4b, 0x3f, 0x41, 0x42, 0x3f, 0x53, 0x4a, 0x48, 0x49, 0x4a, 0x4a, 0x45, + 0x52, 0x3f, 0x52, 0x52, 0x45, 0x4d, 0x4f, 0x45, 0x46, 0x4a, 0x51, 0x48, + 0x56, 0x47, 0x50, 0x3e, 0x46, 0x49, 0x4c, 0x51, 0x49, 0x54, 0x45, 0x4f, + 0x4b, 0x4b, 0x49, 0x46, 0x4b, 0x4d, 0x49, 0x5c, 0x4d, 0x43, 0x47, 0x49, + 0x48, 0x52, 0x46, 0x50, 0x51, 0x37, 0x50, 0x52, 0x4c, 0x4d, 0x4f, 0x51, + 0x4f, 0x42, 0x50, 0x47, 0x48, 0x4e, 0x4d, 0x4c, 0x48, 0x48, 0x4a, 0x51, + 0x49, 0x42, 0x50, 0x4f, 0x43, 0x4e, 0x47, 0x4b, 0x47, 0x4a, 0x44, 0x44, + 0x4c, 0x51, 0x49, 0x44, 0x45, 0x45, 0x45, 0x48, 0x3f, 0x4a, 0x43, 0x49, + 0x46, 0x49, 0x4c, 0x4d, 0x45, 0x50, 0x44, 0x45, 0x44, 0x55, 0x4a, 0x45, + 0x48, 0x47, 0x4c, 0x43, 0x3f, 0x48, 0x42, 0x43, 0x43, 0x43, 0x48, 0x46, + 0x5c, 0x51, 0x47, 0x51, 0x48, 0x40, 0x54, 0x66, 0x4e, 0x67, 0x4d, 0x5a, + 0x60, 0x57, 0x47, 0x4d, 0x4d, 0x58, 0x53, 0x46, 0x66, 0x7e, 0x56, 0x48, + 0x44, 0x4f, 0x49, 0x5c, 0x4a, 0x63, 0x50, 0x4c, 0x49, 0x56, 0x61, 0x50, + 0x2c, 0x68, 0x4d, 0x51, 0x46, 0x4e, 0x5b, 0x51, 0x2e, 0x53, 0x54, 0x50, + 0x46, 0x58, 0x44, 0x4f, 0x37, 0x48, 0x55, 0x50, 0x49, 0x49, 0x4e, 0x46, + 0x43, 0x56, 0x52, 0x4e, 0x50, 0x4b, 0x50, 0x4c, 0x49, 0x40, 0x4d, 0x4f, + 0x50, 0x41, 0x44, 0x39, 0x4b, 0x4d, 0x4b, 0x41, 0x51, 0x4d, 0x4c, 0x41, + 0x3f, 0x52, 0x4e, 0x4b, 0x49, 0x53, 0x45, 0x43, 0x4d, 0x4f, 0x44, 0x4d, + 0x4b, 0x53, 0x50, 0x4e, 0x45, 0x3f, 0x4e, 0x51, 0x50, 0x55, 0x4f, 0x51, + 0x4d, 0x3d, 0x58, 0x3f, 0x46, 0x50, 0x50, 0x50, 0x56, 0x42, 0x49, 0x49, + 0x50, 0x4f, 0x42, 0x4b, 0x4c, 0x45, 0x52, 0x41, 0x46, 0x43, 0x4c, 0x4a, + 0x4c, 0x51, 0x4d, 0x4d, 0x4a, 0x49, 0x54, 0x49, 0x58, 0x53, 0x49, 0x45, + 0x47, 0x4c, 0x4c, 0x44, 0x4e, 0x51, 0x4c, 0x4c, 0x47, 0x48, 0x4c, 0x4e, + 0x49, 0x54, 0x4c, 0x51, 0x49, 0x48, 0x47, 0x45, 0x42, 0x49, 0x42, 0x51, + 0x4e, 0x3f, 0x49, 0x41, 0x50, 0x3e, 0x4d, 0x50, 0x5c, 0x51, 0x4d, 0x56, + 0x47, 0x48, 0x58, 0x65, 0x51, 0x6b, 0x56, 0x5b, 0x56, 0x55, 0x46, 0x49, + 0x4b, 0x58, 0x59, 0x4a, 0x68, 0x79, 0x53, 0x46, 0x45, 0x4b, 0x53, 0x5d, + 0x4b, 0x6f, 0x4e, 0x4f, 0x4c, 0x53, 0x5b, 0x52, 0x30, 0x63, 0x46, 0x57, + 0x46, 0x50, 0x4b, 0x48, 0x2e, 0x4c, 0x46, 0x48, 0x44, 0x51, 0x46, 0x4a, + 0x35, 0x55, 0x43, 0x4c, 0x43, 0x4d, 0x4e, 0x3e, 0x47, 0x56, 0x50, 0x4d, + 0x44, 0x59, 0x4c, 0x51, 0x46, 0x42, 0x4e, 0x43, 0x4c, 0x44, 0x42, 0x3a, + 0x40, 0x48, 0x46, 0x44, 0x45, 0x4a, 0x46, 0x3a, 0x53, 0x4c, 0x4d, 0x4c, + 0x4a, 0x4f, 0x53, 0x40, 0x4b, 0x48, 0x54, 0x4b, 0x44, 0x59, 0x41, 0x50, + 0x4e, 0x50, 0x55, 0x4d, 0x55, 0x41, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x50, + 0x52, 0x4c, 0x50, 0x4d, 0x47, 0x42, 0x4f, 0x4b, 0x47, 0x43, 0x41, 0x4a, + 0x55, 0x3e, 0x50, 0x4b, 0x41, 0x49, 0x47, 0x49, 0x53, 0x4d, 0x48, 0x4b, + 0x43, 0x43, 0x51, 0x44, 0x4d, 0x4c, 0x44, 0x50, 0x4d, 0x42, 0x49, 0x4e, + 0x50, 0x50, 0x4c, 0x49, 0x49, 0x51, 0x46, 0x43, 0x4a, 0x4e, 0x53, 0x47, + 0x43, 0x46, 0x40, 0x49, 0x47, 0x44, 0x44, 0x4d, 0x4b, 0x4b, 0x51, 0x4b, + 0x45, 0x49, 0x47, 0x43, 0x56, 0x49, 0x4c, 0x54, 0x50, 0x3c, 0x4c, 0x5e, + 0x51, 0x67, 0x4f, 0x57, 0x57, 0x53, 0x3e, 0x4e, 0x4e, 0x5e, 0x4b, 0x48, + 0x5a, 0x78, 0x55, 0x4a, 0x3f, 0x4b, 0x4c, 0x5b, 0x53, 0x64, 0x4d, 0x53, + 0x49, 0x57, 0x57, 0x58, 0x37, 0x62, 0x4f, 0x56, 0x44, 0x4e, 0x58, 0x4a, + 0x30, 0x4f, 0x40, 0x4e, 0x47, 0x58, 0x52, 0x50, 0x35, 0x4d, 0x49, 0x52, + 0x4e, 0x42, 0x46, 0x47, 0x44, 0x57, 0x54, 0x43, 0x4e, 0x56, 0x43, 0x49, + 0x44, 0x40, 0x44, 0x41, 0x50, 0x49, 0x4b, 0x44, 0x4d, 0x52, 0x49, 0x43, + 0x52, 0x54, 0x49, 0x3f, 0x49, 0x42, 0x49, 0x4a, 0x43, 0x3e, 0x50, 0x40, + 0x46, 0x4b, 0x50, 0x4b, 0x53, 0x4b, 0x47, 0x52, 0x51, 0x4b, 0x47, 0x3f, + 0x46, 0x4b, 0x4c, 0x57, 0x49, 0x47, 0x54, 0x49, 0x50, 0x50, 0x4d, 0x4a, + 0x42, 0x4e, 0x51, 0x4c, 0x47, 0x47, 0x42, 0x43, 0x54, 0x43, 0x46, 0x47, + 0x4d, 0x43, 0x54, 0x47, 0x43, 0x58, 0x48, 0x45, 0x4b, 0x46, 0x48, 0x3d, + 0x47, 0x3f, 0x44, 0x4f, 0x4e, 0x46, 0x41, 0x40, 0x4d, 0x4d, 0x4d, 0x52, + 0x54, 0x47, 0x4f, 0x51, 0x4f, 0x45, 0x45, 0x48, 0x4b, 0x4d, 0x44, 0x52, + 0x51, 0x4b, 0x48, 0x4f, 0x49, 0x49, 0x46, 0x50, 0x54, 0x42, 0x44, 0x51, + 0x58, 0x4e, 0x43, 0x58, 0x55, 0x40, 0x53, 0x5a, 0x51, 0x61, 0x51, 0x60, + 0x53, 0x57, 0x45, 0x4f, 0x45, 0x5e, 0x51, 0x42, 0x61, 0x7a, 0x55, 0x47, + 0x41, 0x4b, 0x4a, 0x5b, 0x4c, 0x65, 0x4f, 0x55, 0x46, 0x54, 0x65, 0x59, + 0x36, 0x61, 0x54, 0x55, 0x48, 0x57, 0x52, 0x4e, 0x24, 0x4b, 0x49, 0x4d, + 0x43, 0x57, 0x44, 0x51, 0x3b, 0x4f, 0x45, 0x40, 0x47, 0x4a, 0x43, 0x47, + 0x46, 0x58, 0x50, 0x54, 0x4d, 0x50, 0x44, 0x42, 0x4a, 0x46, 0x4b, 0x4d, + 0x4f, 0x4f, 0x4d, 0x40, 0x48, 0x4a, 0x53, 0x48, 0x49, 0x48, 0x4d, 0x39, + 0x47, 0x4e, 0x44, 0x4c, 0x4b, 0x49, 0x44, 0x42, 0x4a, 0x45, 0x46, 0x46, + 0x53, 0x4d, 0x49, 0x4f, 0x4e, 0x48, 0x50, 0x4a, 0x4c, 0x46, 0x56, 0x4b, + 0x4b, 0x57, 0x4c, 0x49, 0x4a, 0x4a, 0x43, 0x4e, 0x56, 0x45, 0x50, 0x4c, + 0x47, 0x55, 0x48, 0x46, 0x4e, 0x46, 0x45, 0x3f, 0x4a, 0x4c, 0x4c, 0x47, + 0x4a, 0x51, 0x4e, 0x50, 0x40, 0x52, 0x45, 0x45, 0x4b, 0x46, 0x4f, 0x44, + 0x51, 0x4a, 0x4e, 0x4d, 0x4c, 0x46, 0x42, 0x47, 0x4a, 0x4e, 0x46, 0x42, + 0x4b, 0x4f, 0x4b, 0x4e, 0x4e, 0x46, 0x42, 0x50, 0x53, 0x51, 0x4f, 0x54, + 0x45, 0x4f, 0x45, 0x42, 0x4c, 0x45, 0x40, 0x48, 0x59, 0x49, 0x49, 0x53, + 0x4c, 0x43, 0x4b, 0x57, 0x54, 0x64, 0x4e, 0x5f, 0x5c, 0x59, 0x4b, 0x56, + 0x49, 0x5d, 0x4f, 0x4b, 0x62, 0x73, 0x54, 0x45, 0x49, 0x50, 0x48, 0x5a, + 0x50, 0x6d, 0x4a, 0x4e, 0x48, 0x55, 0x5d, 0x57, 0x38, 0x68, 0x52, 0x5a, + 0x46, 0x56, 0x4c, 0x5a, 0x2e, 0x55, 0x49, 0x4f, 0x4a, 0x57, 0x4f, 0x54, + 0x41, 0x53, 0x46, 0x43, 0x45, 0x47, 0x53, 0x4a, 0x42, 0x4f, 0x4d, 0x48, + 0x4c, 0x49, 0x47, 0x48, 0x45, 0x49, 0x48, 0x53, 0x48, 0x52, 0x4a, 0x44, + 0x4c, 0x49, 0x52, 0x4b, 0x47, 0x51, 0x42, 0x47, 0x49, 0x51, 0x3f, 0x45, + 0x47, 0x4e, 0x53, 0x33, 0x55, 0x51, 0x55, 0x48, 0x4b, 0x51, 0x56, 0x47, + 0x43, 0x55, 0x47, 0x42, 0x47, 0x4f, 0x47, 0x51, 0x46, 0x55, 0x4a, 0x4b, + 0x50, 0x52, 0x4f, 0x43, 0x4b, 0x53, 0x4d, 0x3f, 0x4e, 0x56, 0x50, 0x49, + 0x4d, 0x47, 0x51, 0x49, 0x4a, 0x52, 0x44, 0x43, 0x4d, 0x4e, 0x41, 0x51, + 0x4c, 0x4d, 0x47, 0x48, 0x4f, 0x40, 0x50, 0x46, 0x43, 0x4d, 0x4e, 0x50, + 0x43, 0x47, 0x4e, 0x46, 0x4f, 0x4b, 0x51, 0x4b, 0x4a, 0x57, 0x42, 0x51, + 0x4c, 0x54, 0x52, 0x42, 0x4c, 0x42, 0x47, 0x54, 0x4a, 0x4a, 0x47, 0x4a, + 0x3f, 0x46, 0x4e, 0x4c, 0x53, 0x50, 0x47, 0x53, 0x49, 0x44, 0x52, 0x5a, + 0x4b, 0x65, 0x50, 0x5b, 0x57, 0x59, 0x4a, 0x48, 0x48, 0x5f, 0x55, 0x48, + 0x5c, 0x78, 0x55, 0x48, 0x4a, 0x4b, 0x49, 0x4c, 0x46, 0x6b, 0x54, 0x57, + 0x55, 0x4b, 0x59, 0x52, 0x38, 0x5b, 0x57, 0x56, 0x4b, 0x4f, 0x48, 0x4e, + 0x34, 0x5a, 0x4e, 0x4f, 0x43, 0x4e, 0x4b, 0x4e, 0x36, 0x4d, 0x52, 0x48, + 0x4d, 0x4c, 0x4c, 0x49, 0x51, 0x54, 0x45, 0x54, 0x4a, 0x4e, 0x52, 0x41, + 0x4c, 0x45, 0x4a, 0x53, 0x55, 0x4b, 0x50, 0x47, 0x4e, 0x4d, 0x43, 0x51, + 0x4e, 0x4a, 0x51, 0x46, 0x4e, 0x4d, 0x48, 0x3f, 0x43, 0x52, 0x56, 0x38, + 0x52, 0x46, 0x43, 0x49, 0x40, 0x49, 0x53, 0x41, 0x47, 0x41, 0x41, 0x42, + 0x4f, 0x4b, 0x46, 0x4b, 0x4a, 0x57, 0x4a, 0x45, 0x4b, 0x46, 0x47, 0x3c, + 0x43, 0x46, 0x4f, 0x50, 0x4c, 0x53, 0x4f, 0x41, 0x4a, 0x4a, 0x40, 0x4a, + 0x3e, 0x4e, 0x4d, 0x41, 0x4a, 0x42, 0x49, 0x4c, 0x51, 0x46, 0x4f, 0x43, + 0x4b, 0x41, 0x50, 0x48, 0x4a, 0x40, 0x52, 0x45, 0x40, 0x40, 0x46, 0x48, + 0x48, 0x52, 0x52, 0x41, 0x43, 0x49, 0x49, 0x4c, 0x44, 0x48, 0x50, 0x4a, + 0x47, 0x48, 0x4c, 0x42, 0x49, 0x48, 0x52, 0x56, 0x4b, 0x41, 0x4e, 0x47, + 0x52, 0x56, 0x4e, 0x56, 0x4b, 0x38, 0x50, 0x55, 0x5a, 0x63, 0x51, 0x5a, + 0x54, 0x52, 0x44, 0x45, 0x47, 0x5e, 0x4c, 0x4a, 0x5e, 0x71, 0x56, 0x44, + 0x4c, 0x4b, 0x4c, 0x4e, 0x49, 0x69, 0x50, 0x53, 0x4d, 0x5c, 0x59, 0x50, + 0x36, 0x5d, 0x46, 0x5b, 0x51, 0x55, 0x55, 0x51, 0x36, 0x5a, 0x53, 0x56, + 0x54, 0x4a, 0x55, 0x53, 0x3c, 0x52, 0x4a, 0x45, 0x4c, 0x56, 0x49, 0x46, + 0x4f, 0x5b, 0x43, 0x4b, 0x49, 0x4c, 0x4b, 0x41, 0x44, 0x4b, 0x47, 0x4b, + 0x4b, 0x54, 0x4a, 0x4c, 0x49, 0x44, 0x46, 0x46, 0x48, 0x49, 0x47, 0x4a, + 0x40, 0x4e, 0x47, 0x53, 0x4a, 0x47, 0x4a, 0x3b, 0x48, 0x4b, 0x50, 0x51, + 0x50, 0x44, 0x4d, 0x49, 0x42, 0x4b, 0x43, 0x48, 0x4a, 0x43, 0x4d, 0x4d, + 0x49, 0x4d, 0x43, 0x4f, 0x50, 0x49, 0x47, 0x48, 0x48, 0x4f, 0x49, 0x41, + 0x4c, 0x46, 0x47, 0x3e, 0x51, 0x4d, 0x4e, 0x42, 0x3d, 0x53, 0x4d, 0x3b, + 0x53, 0x52, 0x4c, 0x4c, 0x43, 0x46, 0x43, 0x3d, 0x53, 0x48, 0x43, 0x4e, + 0x45, 0x52, 0x4d, 0x4a, 0x44, 0x49, 0x47, 0x4c, 0x4e, 0x4c, 0x4a, 0x4e, + 0x41, 0x48, 0x4b, 0x44, 0x4d, 0x4a, 0x4d, 0x44, 0x4a, 0x45, 0x4f, 0x52, + 0x45, 0x3f, 0x4b, 0x48, 0x43, 0x41, 0x3d, 0x53, 0x53, 0x50, 0x4a, 0x56, + 0x4d, 0x3e, 0x55, 0x4e, 0x56, 0x5e, 0x52, 0x52, 0x54, 0x50, 0x42, 0x4a, + 0x4d, 0x5f, 0x4f, 0x49, 0x5d, 0x6f, 0x55, 0x4a, 0x47, 0x49, 0x4e, 0x4a, + 0x43, 0x6e, 0x4e, 0x4f, 0x52, 0x59, 0x62, 0x4b, 0x3e, 0x5c, 0x4c, 0x4e, + 0x45, 0x52, 0x43, 0x4d, 0x3c, 0x58, 0x52, 0x49, 0x48, 0x55, 0x53, 0x4e, + 0x3d, 0x4e, 0x4c, 0x4b, 0x4b, 0x50, 0x4a, 0x47, 0x45, 0x62, 0x50, 0x49, + 0x48, 0x4b, 0x55, 0x45, 0x46, 0x51, 0x41, 0x55, 0x54, 0x55, 0x50, 0x47, + 0x46, 0x4d, 0x46, 0x4b, 0x41, 0x49, 0x4c, 0x40, 0x45, 0x4f, 0x52, 0x54, + 0x45, 0x4d, 0x53, 0x3a, 0x4c, 0x55, 0x4e, 0x48, 0x44, 0x45, 0x56, 0x3c, + 0x48, 0x46, 0x4b, 0x51, 0x53, 0x43, 0x41, 0x49, 0x4c, 0x52, 0x48, 0x42, + 0x48, 0x3f, 0x4c, 0x38, 0x46, 0x50, 0x4a, 0x44, 0x50, 0x54, 0x4e, 0x38, + 0x48, 0x42, 0x43, 0x4a, 0x4c, 0x44, 0x47, 0x42, 0x42, 0x46, 0x4a, 0x50, + 0x47, 0x4b, 0x43, 0x40, 0x44, 0x46, 0x46, 0x4d, 0x50, 0x4a, 0x4e, 0x51, + 0x44, 0x40, 0x50, 0x43, 0x52, 0x4d, 0x42, 0x4c, 0x50, 0x41, 0x4a, 0x4e, + 0x45, 0x49, 0x4d, 0x40, 0x46, 0x51, 0x43, 0x4b, 0x48, 0x47, 0x42, 0x55, + 0x4a, 0x41, 0x4f, 0x49, 0x4f, 0x4e, 0x47, 0x4c, 0x4a, 0x48, 0x50, 0x4e, + 0x50, 0x57, 0x4e, 0x56, 0x56, 0x4e, 0x44, 0x48, 0x4a, 0x5b, 0x55, 0x49, + 0x59, 0x67, 0x54, 0x46, 0x4f, 0x41, 0x4d, 0x4e, 0x4a, 0x63, 0x4d, 0x44, + 0x53, 0x5b, 0x59, 0x4f, 0x43, 0x55, 0x56, 0x4e, 0x55, 0x4c, 0x4b, 0x54, + 0x3c, 0x56, 0x4d, 0x50, 0x4f, 0x4a, 0x5a, 0x47, 0x48, 0x56, 0x4f, 0x4f, + 0x50, 0x51, 0x48, 0x4e, 0x4d, 0x50, 0x4e, 0x45, 0x4b, 0x48, 0x4e, 0x44, + 0x46, 0x4d, 0x43, 0x46, 0x41, 0x59, 0x53, 0x4b, 0x4a, 0x3e, 0x51, 0x47, + 0x43, 0x48, 0x52, 0x3f, 0x43, 0x50, 0x4b, 0x4f, 0x41, 0x48, 0x43, 0x2e, + 0x4d, 0x4e, 0x4c, 0x45, 0x45, 0x46, 0x4b, 0x43, 0x46, 0x49, 0x46, 0x4d, + 0x47, 0x4e, 0x4d, 0x3c, 0x47, 0x4a, 0x52, 0x4e, 0x41, 0x50, 0x43, 0x3a, + 0x50, 0x47, 0x4a, 0x45, 0x52, 0x4a, 0x4c, 0x3f, 0x42, 0x3d, 0x49, 0x48, + 0x48, 0x4c, 0x42, 0x3a, 0x40, 0x47, 0x46, 0x4e, 0x44, 0x52, 0x46, 0x44, + 0x4a, 0x44, 0x43, 0x49, 0x42, 0x45, 0x3f, 0x50, 0x4c, 0x44, 0x48, 0x43, + 0x47, 0x4a, 0x48, 0x48, 0x3e, 0x45, 0x43, 0x48, 0x4a, 0x48, 0x53, 0x4b, + 0x50, 0x49, 0x43, 0x4d, 0x53, 0x4f, 0x4b, 0x4b, 0x40, 0x42, 0x50, 0x4d, + 0x53, 0x4e, 0x44, 0x4d, 0x45, 0x3d, 0x51, 0x51, 0x4f, 0x59, 0x4b, 0x51, + 0x4a, 0x4e, 0x42, 0x40, 0x49, 0x5b, 0x4b, 0x43, 0x53, 0x60, 0x47, 0x49, + 0x4a, 0x44, 0x44, 0x48, 0x4b, 0x60, 0x51, 0x3f, 0x4b, 0x5b, 0x4f, 0x4a, + 0x4a, 0x50, 0x49, 0x46, 0x55, 0x50, 0x4b, 0x4c, 0x40, 0x4e, 0x51, 0x4f, + 0x4b, 0x51, 0x54, 0x50, 0x48, 0x4e, 0x4a, 0x4f, 0x4d, 0x4e, 0x54, 0x4d, + 0x41, 0x50, 0x4e, 0x47, 0x47, 0x47, 0x54, 0x3b, 0x51, 0x54, 0x50, 0x49, + 0x48, 0x4c, 0x4e, 0x47, 0x3f, 0x3c, 0x4c, 0x43, 0x45, 0x42, 0x45, 0x37, + 0x41, 0x52, 0x49, 0x47, 0x4e, 0x4a, 0x4b, 0x37, 0x48, 0x4d, 0x4e, 0x4a, + 0x42, 0x56, 0x3d, 0x35, 0x48, 0x42, 0x4b, 0x4a, 0x44, 0x52, 0x40, 0x48, + 0x4f, 0x49, 0x4f, 0x4c, 0x4d, 0x43, 0x49, 0x38, 0x4b, 0x42, 0x48, 0x42, + 0x45, 0x45, 0x54, 0x3a, 0x47, 0x47, 0x52, 0x45, 0x4a, 0x48, 0x47, 0x39, + 0x4d, 0x45, 0x54, 0x4b, 0x4e, 0x4f, 0x4e, 0x38, 0x4a, 0x4b, 0x48, 0x45, + 0x4e, 0x43, 0x4e, 0x4e, 0x46, 0x4e, 0x4e, 0x50, 0x46, 0x4c, 0x42, 0x45, + 0x4b, 0x46, 0x47, 0x4d, 0x49, 0x3f, 0x4f, 0x50, 0x46, 0x4a, 0x47, 0x4e, + 0x4a, 0x3e, 0x50, 0x46, 0x47, 0x40, 0x4f, 0x47, 0x51, 0x4b, 0x43, 0x46, + 0x4a, 0x42, 0x55, 0x4d, 0x46, 0x63, 0x49, 0x4e, 0x4f, 0x4f, 0x42, 0x45, + 0x50, 0x57, 0x49, 0x3e, 0x57, 0x63, 0x45, 0x4a, 0x49, 0x50, 0x41, 0x4a, + 0x48, 0x64, 0x4f, 0x42, 0x47, 0x58, 0x4b, 0x45, 0x43, 0x57, 0x49, 0x58, + 0x51, 0x51, 0x47, 0x43, 0x51, 0x4b, 0x4a, 0x45, 0x50, 0x54, 0x4d, 0x4d, + 0x3e, 0x4a, 0x50, 0x40, 0x51, 0x4f, 0x52, 0x48, 0x53, 0x49, 0x44, 0x4b, + 0x51, 0x4b, 0x50, 0x42, 0x4d, 0x49, 0x4a, 0x46, 0x44, 0x50, 0x47, 0x3f, + 0x48, 0x47, 0x41, 0x4a, 0x42, 0x52, 0x4a, 0x33, 0x50, 0x50, 0x54, 0x3f, + 0x44, 0x4e, 0x51, 0x3c, 0x4e, 0x51, 0x48, 0x4b, 0x47, 0x49, 0x3f, 0x3d, + 0x4e, 0x46, 0x4a, 0x41, 0x40, 0x50, 0x49, 0x40, 0x4a, 0x4b, 0x45, 0x50, + 0x4e, 0x4d, 0x4b, 0x39, 0x4e, 0x4b, 0x48, 0x3c, 0x47, 0x44, 0x4c, 0x42, + 0x45, 0x50, 0x3e, 0x54, 0x4d, 0x49, 0x48, 0x3c, 0x45, 0x42, 0x55, 0x4a, + 0x41, 0x4f, 0x40, 0x3f, 0x47, 0x46, 0x46, 0x44, 0x4f, 0x47, 0x46, 0x44, + 0x41, 0x40, 0x44, 0x48, 0x3e, 0x3c, 0x46, 0x3e, 0x4a, 0x45, 0x4c, 0x52, + 0x47, 0x42, 0x47, 0x3f, 0x47, 0x4e, 0x4b, 0x53, 0x4a, 0x3d, 0x4d, 0x47, + 0x4f, 0x3d, 0x4e, 0x43, 0x4f, 0x46, 0x43, 0x43, 0x46, 0x41, 0x4f, 0x42, + 0x46, 0x57, 0x4d, 0x51, 0x49, 0x51, 0x4c, 0x44, 0x51, 0x4f, 0x46, 0x44, + 0x54, 0x5d, 0x4f, 0x40, 0x59, 0x46, 0x53, 0x46, 0x48, 0x54, 0x43, 0x45, + 0x4d, 0x51, 0x4f, 0x44, 0x44, 0x53, 0x49, 0x4e, 0x48, 0x46, 0x44, 0x4a, + 0x4a, 0x42, 0x4c, 0x46, 0x54, 0x4f, 0x52, 0x47, 0x46, 0x44, 0x4c, 0x4d, + 0x4c, 0x47, 0x4d, 0x40, 0x55, 0x58, 0x46, 0x46, 0x3f, 0x3e, 0x47, 0x36, + 0x3f, 0x4d, 0x4b, 0x4d, 0x4f, 0x4f, 0x48, 0x34, 0x4d, 0x46, 0x46, 0x50, + 0x50, 0x4b, 0x47, 0x45, 0x4e, 0x49, 0x50, 0x4f, 0x4a, 0x48, 0x4f, 0x39, + 0x53, 0x4c, 0x4b, 0x56, 0x45, 0x4f, 0x55, 0x3a, 0x40, 0x53, 0x43, 0x4b, + 0x47, 0x3d, 0x4c, 0x34, 0x4b, 0x4e, 0x4a, 0x4b, 0x4d, 0x49, 0x4e, 0x40, + 0x4d, 0x48, 0x40, 0x4a, 0x4a, 0x4b, 0x4a, 0x42, 0x4c, 0x52, 0x43, 0x42, + 0x44, 0x3f, 0x4e, 0x42, 0x44, 0x45, 0x40, 0x3d, 0x4b, 0x45, 0x4a, 0x43, + 0x4b, 0x4b, 0x4e, 0x46, 0x55, 0x43, 0x44, 0x3f, 0x44, 0x43, 0x4b, 0x4b, + 0x45, 0x51, 0x48, 0x49, 0x3d, 0x44, 0x4a, 0x4a, 0x50, 0x50, 0x47, 0x44, + 0x4f, 0x3e, 0x3f, 0x43, 0x4c, 0x46, 0x4a, 0x4e, 0x4c, 0x52, 0x48, 0x4e, + 0x48, 0x46, 0x45, 0x48, 0x41, 0x4f, 0x51, 0x48, 0x40, 0x4d, 0x4a, 0x4b, + 0x4c, 0x51, 0x49, 0x50, 0x4e, 0x4b, 0x4a, 0x42, 0x49, 0x54, 0x4e, 0x43, + 0x52, 0x47, 0x4a, 0x41, 0x42, 0x51, 0x48, 0x4a, 0x46, 0x45, 0x4a, 0x43, + 0x4e, 0x4f, 0x41, 0x49, 0x4b, 0x42, 0x40, 0x4a, 0x50, 0x41, 0x42, 0x3f, + 0x49, 0x4a, 0x40, 0x3e, 0x3f, 0x42, 0x4d, 0x51, 0x4e, 0x4e, 0x47, 0x41, + 0x4e, 0x4e, 0x49, 0x4b, 0x41, 0x45, 0x51, 0x40, 0x45, 0x4c, 0x3f, 0x42, + 0x4c, 0x45, 0x4d, 0x39, 0x46, 0x52, 0x4a, 0x4e, 0x4c, 0x49, 0x4e, 0x43, + 0x43, 0x4c, 0x48, 0x46, 0x48, 0x49, 0x50, 0x3a, 0x3f, 0x49, 0x42, 0x4f, + 0x42, 0x4d, 0x4e, 0x3f, 0x51, 0x4b, 0x4e, 0x4b, 0x51, 0x44, 0x43, 0x4a, + 0x4a, 0x4c, 0x50, 0x48, 0x45, 0x47, 0x4d, 0x41, 0x47, 0x45, 0x51, 0x41, + 0x42, 0x48, 0x4c, 0x39, 0x51, 0x45, 0x46, 0x53, 0x4b, 0x50, 0x46, 0x45, + 0x4b, 0x4d, 0x42, 0x4b, 0x3f, 0x45, 0x4b, 0x4e, 0x50, 0x50, 0x47, 0x4a, + 0x45, 0x40, 0x4b, 0x43, 0x3f, 0x4a, 0x41, 0x42, 0x51, 0x41, 0x4d, 0x42, + 0x53, 0x48, 0x48, 0x49, 0x4b, 0x40, 0x42, 0x3d, 0x4f, 0x53, 0x49, 0x46, + 0x46, 0x43, 0x42, 0x44, 0x46, 0x48, 0x3f, 0x46, 0x31, 0x43, 0x4d, 0x4b, + 0x48, 0x4d, 0x4c, 0x43, 0x45, 0x53, 0x50, 0x40, 0x4a, 0x48, 0x45, 0x3b, + 0x4f, 0x4d, 0x53, 0x4c, 0x44, 0x54, 0x50, 0x66, 0x3f, 0x45, 0x4c, 0x4c, + 0x4a, 0x49, 0x49, 0x4a, 0x40, 0x52, 0x3e, 0x4c, 0x49, 0x40, 0x44, 0x49, + 0x48, 0x3f, 0x45, 0x5b, 0x49, 0x4b, 0x4c, 0x44, 0x50, 0x4e, 0x4a, 0x4a, + 0x49, 0x4e, 0x4f, 0x47, 0x46, 0x4b, 0x44, 0x3b, 0x4e, 0x4b, 0x48, 0x46, + 0x45, 0x45, 0x3d, 0x35, 0x4c, 0x49, 0x54, 0x42, 0x51, 0x46, 0x49, 0x2d, + 0x43, 0x4a, 0x53, 0x49, 0x49, 0x42, 0x4f, 0x40, 0x4e, 0x50, 0x54, 0x51, + 0x4b, 0x45, 0x48, 0x35, 0x4d, 0x41, 0x51, 0x40, 0x41, 0x49, 0x4a, 0x3b, + 0x45, 0x50, 0x48, 0x51, 0x51, 0x4d, 0x4c, 0x36, 0x47, 0x4a, 0x44, 0x45, + 0x4d, 0x47, 0x43, 0x3a, 0x48, 0x40, 0x42, 0x4f, 0x4f, 0x4f, 0x4f, 0x43, + 0x4a, 0x41, 0x4b, 0x53, 0x43, 0x46, 0x4f, 0x39, 0x46, 0x4a, 0x4d, 0x53, + 0x41, 0x44, 0x4e, 0x44, 0x3f, 0x47, 0x4c, 0x4d, 0x4d, 0x43, 0x45, 0x3d, + 0x43, 0x4b, 0x3e, 0x48, 0x42, 0x4c, 0x47, 0x42, 0x42, 0x50, 0x49, 0x4b, + 0x43, 0x4e, 0x44, 0x44, 0x4c, 0x3d, 0x4c, 0x47, 0x4e, 0x42, 0x4b, 0x44, + 0x4b, 0x44, 0x3f, 0x49, 0x33, 0x46, 0x4a, 0x4a, 0x42, 0x57, 0x5e, 0x4a, + 0x46, 0x4f, 0x55, 0x3c, 0x4a, 0x4b, 0x4c, 0x43, 0x51, 0x59, 0x64, 0x51, + 0x45, 0x60, 0x4b, 0x65, 0x46, 0x4a, 0x4e, 0x49, 0x41, 0x4b, 0x50, 0x5c, + 0x48, 0x4b, 0x3e, 0x52, 0x4f, 0x2f, 0x4e, 0x4a, 0x45, 0x53, 0x48, 0x59, + 0x4c, 0x4e, 0x4a, 0x4d, 0x49, 0x40, 0x52, 0x44, 0x49, 0x46, 0x4e, 0x46, + 0x42, 0x4b, 0x4a, 0x4b, 0x4b, 0x4b, 0x4f, 0x52, 0x46, 0x50, 0x4d, 0x3d, + 0x46, 0x4b, 0x4b, 0x40, 0x4d, 0x3f, 0x43, 0x33, 0x4e, 0x53, 0x4b, 0x4a, + 0x45, 0x48, 0x4c, 0x2e, 0x48, 0x4f, 0x49, 0x42, 0x54, 0x4f, 0x4b, 0x2b, + 0x55, 0x4e, 0x43, 0x4d, 0x4d, 0x47, 0x42, 0x3e, 0x48, 0x48, 0x4d, 0x54, + 0x52, 0x4f, 0x43, 0x37, 0x4b, 0x42, 0x4b, 0x4e, 0x49, 0x49, 0x4b, 0x2e, + 0x45, 0x4e, 0x48, 0x4e, 0x44, 0x49, 0x48, 0x30, 0x4c, 0x4b, 0x3f, 0x42, + 0x4f, 0x4f, 0x4e, 0x38, 0x4f, 0x42, 0x54, 0x49, 0x41, 0x42, 0x45, 0x3a, + 0x47, 0x43, 0x43, 0x4b, 0x49, 0x40, 0x4d, 0x38, 0x52, 0x4c, 0x3d, 0x4d, + 0x43, 0x54, 0x4e, 0x41, 0x4a, 0x47, 0x44, 0x51, 0x47, 0x48, 0x41, 0x47, + 0x4d, 0x41, 0x46, 0x4c, 0x4d, 0x46, 0x51, 0x4a, 0x49, 0x46, 0x4a, 0x42, + 0x3a, 0x43, 0x4a, 0x4b, 0x43, 0x4c, 0x68, 0x44, 0x4b, 0x52, 0x50, 0x37, + 0x4d, 0x4c, 0x57, 0x4c, 0x68, 0x62, 0x64, 0x4a, 0x3e, 0x64, 0x4b, 0x66, + 0x48, 0x4d, 0x54, 0x57, 0x4b, 0x52, 0x49, 0x5c, 0x4d, 0x55, 0x51, 0x57, + 0x4c, 0x3a, 0x48, 0x43, 0x3b, 0x43, 0x52, 0x5d, 0x45, 0x4e, 0x51, 0x4d, + 0x4a, 0x55, 0x4e, 0x4c, 0x44, 0x51, 0x4c, 0x4f, 0x41, 0x4f, 0x4a, 0x43, + 0x53, 0x48, 0x47, 0x49, 0x46, 0x52, 0x48, 0x3e, 0x4b, 0x4e, 0x4a, 0x50, + 0x4f, 0x47, 0x3e, 0x2e, 0x4b, 0x51, 0x4a, 0x44, 0x4c, 0x49, 0x4f, 0x26, + 0x48, 0x4f, 0x44, 0x51, 0x48, 0x3f, 0x4c, 0x30, 0x4e, 0x48, 0x4d, 0x48, + 0x48, 0x44, 0x4b, 0x2f, 0x50, 0x41, 0x4d, 0x50, 0x52, 0x42, 0x45, 0x33, + 0x4c, 0x48, 0x48, 0x3d, 0x46, 0x41, 0x43, 0x38, 0x45, 0x4f, 0x48, 0x4b, + 0x41, 0x49, 0x4c, 0x2f, 0x53, 0x4c, 0x48, 0x4a, 0x47, 0x40, 0x4a, 0x31, + 0x52, 0x40, 0x49, 0x4c, 0x3f, 0x48, 0x48, 0x39, 0x48, 0x3f, 0x45, 0x43, + 0x40, 0x48, 0x3c, 0x40, 0x4c, 0x48, 0x48, 0x4d, 0x3e, 0x42, 0x4a, 0x3d, + 0x4c, 0x45, 0x44, 0x46, 0x44, 0x45, 0x4a, 0x47, 0x52, 0x48, 0x4a, 0x4d, + 0x3f, 0x49, 0x4c, 0x4c, 0x48, 0x44, 0x4c, 0x44, 0x3d, 0x41, 0x47, 0x45, + 0x43, 0x4a, 0x5a, 0x3f, 0x48, 0x5d, 0x50, 0x35, 0x47, 0x4f, 0x5b, 0x46, + 0x6e, 0x50, 0x6d, 0x44, 0x49, 0x6a, 0x53, 0x6b, 0x4b, 0x4b, 0x4f, 0x62, + 0x45, 0x57, 0x48, 0x5b, 0x40, 0x4b, 0x4f, 0x63, 0x48, 0x3a, 0x4b, 0x42, + 0x43, 0x53, 0x41, 0x5f, 0x54, 0x3e, 0x4d, 0x43, 0x3d, 0x4c, 0x46, 0x46, + 0x49, 0x56, 0x4b, 0x45, 0x47, 0x45, 0x4e, 0x4f, 0x4c, 0x4d, 0x4f, 0x47, + 0x49, 0x4b, 0x51, 0x33, 0x4b, 0x45, 0x4d, 0x41, 0x51, 0x4a, 0x43, 0x2a, + 0x50, 0x4b, 0x4a, 0x4b, 0x4c, 0x52, 0x4c, 0x3b, 0x45, 0x4c, 0x51, 0x44, + 0x4c, 0x48, 0x43, 0x35, 0x51, 0x50, 0x48, 0x49, 0x3f, 0x48, 0x3d, 0x3b, + 0x52, 0x3f, 0x42, 0x4b, 0x49, 0x49, 0x47, 0x38, 0x4a, 0x4a, 0x41, 0x52, + 0x41, 0x3e, 0x4b, 0x2f, 0x46, 0x4d, 0x49, 0x44, 0x46, 0x3b, 0x47, 0x36, + 0x46, 0x3f, 0x49, 0x48, 0x47, 0x42, 0x42, 0x35, 0x44, 0x4b, 0x4d, 0x56, + 0x50, 0x49, 0x43, 0x42, 0x4b, 0x3e, 0x53, 0x44, 0x4a, 0x43, 0x47, 0x38, + 0x4a, 0x45, 0x4d, 0x3f, 0x46, 0x4a, 0x47, 0x3a, 0x4c, 0x3e, 0x47, 0x45, + 0x46, 0x4b, 0x45, 0x49, 0x4a, 0x4b, 0x54, 0x49, 0x4a, 0x53, 0x4a, 0x4c, + 0x45, 0x48, 0x53, 0x42, 0x4b, 0x47, 0x4e, 0x50, 0x3d, 0x51, 0x60, 0x3e, + 0x53, 0x5d, 0x51, 0x30, 0x45, 0x50, 0x59, 0x4e, 0x62, 0x52, 0x68, 0x51, + 0x45, 0x6c, 0x4c, 0x64, 0x4d, 0x47, 0x55, 0x61, 0x44, 0x57, 0x44, 0x58, + 0x44, 0x4a, 0x53, 0x58, 0x47, 0x31, 0x3f, 0x4c, 0x43, 0x45, 0x48, 0x5e, + 0x41, 0x43, 0x3f, 0x43, 0x51, 0x46, 0x48, 0x4b, 0x4d, 0x5b, 0x45, 0x4b, + 0x48, 0x46, 0x3f, 0x45, 0x47, 0x45, 0x40, 0x4a, 0x51, 0x51, 0x3d, 0x3f, + 0x43, 0x45, 0x4d, 0x4a, 0x47, 0x50, 0x49, 0x32, 0x4c, 0x5a, 0x55, 0x4f, + 0x4c, 0x51, 0x43, 0x37, 0x40, 0x59, 0x49, 0x49, 0x4e, 0x4f, 0x47, 0x34, + 0x40, 0x4c, 0x4a, 0x41, 0x4a, 0x47, 0x4a, 0x42, 0x4e, 0x4a, 0x48, 0x4e, + 0x4e, 0x4e, 0x45, 0x39, 0x4e, 0x45, 0x45, 0x4e, 0x4c, 0x48, 0x4a, 0x35, + 0x45, 0x4c, 0x49, 0x4f, 0x51, 0x43, 0x3c, 0x3a, 0x4a, 0x4a, 0x46, 0x48, + 0x49, 0x42, 0x4e, 0x2f, 0x42, 0x4e, 0x45, 0x50, 0x51, 0x40, 0x45, 0x32, + 0x4a, 0x4d, 0x44, 0x4e, 0x48, 0x48, 0x47, 0x2f, 0x48, 0x4b, 0x49, 0x44, + 0x48, 0x4d, 0x46, 0x3b, 0x46, 0x4a, 0x41, 0x4e, 0x4e, 0x47, 0x54, 0x4b, + 0x45, 0x49, 0x45, 0x44, 0x45, 0x48, 0x4a, 0x46, 0x55, 0x49, 0x47, 0x49, + 0x4b, 0x42, 0x48, 0x4f, 0x3f, 0x52, 0x60, 0x39, 0x4b, 0x5e, 0x55, 0x2e, + 0x48, 0x50, 0x59, 0x4f, 0x68, 0x5f, 0x64, 0x4f, 0x3b, 0x71, 0x50, 0x63, + 0x4f, 0x50, 0x50, 0x6c, 0x4b, 0x55, 0x47, 0x5b, 0x4c, 0x40, 0x48, 0x59, + 0x4f, 0x2e, 0x4b, 0x4c, 0x4e, 0x4e, 0x46, 0x61, 0x50, 0x41, 0x4c, 0x4a, + 0x44, 0x3e, 0x3f, 0x47, 0x4b, 0x4f, 0x47, 0x4b, 0x47, 0x3d, 0x41, 0x49, + 0x49, 0x3f, 0x4d, 0x44, 0x4a, 0x4d, 0x45, 0x41, 0x4d, 0x43, 0x49, 0x3c, + 0x49, 0x57, 0x49, 0x3b, 0x49, 0x59, 0x3f, 0x4f, 0x4e, 0x49, 0x4e, 0x46, + 0x52, 0x4e, 0x4c, 0x54, 0x4a, 0x48, 0x48, 0x3a, 0x44, 0x4a, 0x4f, 0x4a, + 0x44, 0x4b, 0x43, 0x4d, 0x51, 0x42, 0x53, 0x4d, 0x52, 0x41, 0x4d, 0x43, + 0x4e, 0x54, 0x4b, 0x42, 0x4b, 0x3f, 0x53, 0x45, 0x3f, 0x4a, 0x45, 0x50, + 0x3f, 0x4c, 0x4f, 0x43, 0x46, 0x42, 0x4b, 0x4d, 0x4c, 0x3b, 0x48, 0x40, + 0x4e, 0x4e, 0x49, 0x46, 0x4d, 0x4d, 0x52, 0x40, 0x4e, 0x4f, 0x46, 0x4a, + 0x40, 0x4b, 0x4c, 0x40, 0x4f, 0x4a, 0x44, 0x41, 0x46, 0x3c, 0x40, 0x3d, + 0x44, 0x48, 0x4a, 0x50, 0x46, 0x53, 0x46, 0x40, 0x44, 0x3e, 0x47, 0x43, + 0x48, 0x3d, 0x4e, 0x3e, 0x48, 0x49, 0x4b, 0x49, 0x4c, 0x3e, 0x4c, 0x4a, + 0x46, 0x4e, 0x62, 0x3c, 0x59, 0x60, 0x51, 0x29, 0x47, 0x52, 0x59, 0x4c, + 0x67, 0x68, 0x68, 0x4e, 0x3b, 0x72, 0x4d, 0x68, 0x44, 0x4f, 0x53, 0x63, + 0x47, 0x5a, 0x45, 0x4f, 0x4b, 0x37, 0x43, 0x5b, 0x4b, 0x3d, 0x44, 0x41, + 0x4a, 0x4b, 0x3c, 0x64, 0x48, 0x38, 0x42, 0x3f, 0x48, 0x46, 0x4b, 0x46, + 0x46, 0x4f, 0x46, 0x46, 0x44, 0x3c, 0x4b, 0x4f, 0x4d, 0x4a, 0x4b, 0x46, + 0x4d, 0x4f, 0x4f, 0x3f, 0x3a, 0x4b, 0x55, 0x3c, 0x51, 0x56, 0x4d, 0x42, + 0x52, 0x5a, 0x3e, 0x4b, 0x54, 0x57, 0x4e, 0x4d, 0x4e, 0x5b, 0x4e, 0x49, + 0x4e, 0x3c, 0x40, 0x41, 0x40, 0x4d, 0x48, 0x42, 0x49, 0x4e, 0x4f, 0x47, + 0x47, 0x48, 0x50, 0x49, 0x51, 0x46, 0x44, 0x45, 0x49, 0x46, 0x43, 0x48, + 0x48, 0x49, 0x4d, 0x4c, 0x45, 0x4f, 0x4c, 0x45, 0x44, 0x40, 0x49, 0x45, + 0x49, 0x51, 0x4b, 0x4b, 0x50, 0x4b, 0x48, 0x3d, 0x4e, 0x52, 0x4a, 0x47, + 0x49, 0x41, 0x55, 0x3d, 0x48, 0x4d, 0x49, 0x48, 0x4e, 0x4c, 0x48, 0x3d, + 0x3f, 0x4c, 0x4e, 0x53, 0x3e, 0x48, 0x4a, 0x3f, 0x54, 0x4d, 0x54, 0x4b, + 0x47, 0x4e, 0x44, 0x48, 0x49, 0x4b, 0x4c, 0x49, 0x4d, 0x42, 0x52, 0x4b, + 0x40, 0x3e, 0x54, 0x49, 0x55, 0x45, 0x47, 0x4d, 0x45, 0x5c, 0x60, 0x40, + 0x57, 0x60, 0x5b, 0x27, 0x4a, 0x5a, 0x64, 0x53, 0x6a, 0x5a, 0x5f, 0x52, + 0x3a, 0x72, 0x4b, 0x5f, 0x45, 0x56, 0x5f, 0x5f, 0x54, 0x5f, 0x39, 0x52, + 0x51, 0x3e, 0x3b, 0x5a, 0x44, 0x32, 0x46, 0x50, 0x3a, 0x4f, 0x44, 0x5d, + 0x4c, 0x41, 0x39, 0x3f, 0x45, 0x46, 0x3b, 0x43, 0x46, 0x51, 0x3c, 0x4c, + 0x4b, 0x43, 0x4b, 0x51, 0x43, 0x48, 0x4d, 0x43, 0x38, 0x46, 0x46, 0x43, + 0x44, 0x4a, 0x46, 0x49, 0x48, 0x50, 0x4e, 0x4a, 0x4e, 0x58, 0x4a, 0x49, + 0x48, 0x4f, 0x4a, 0x49, 0x41, 0x57, 0x51, 0x50, 0x4b, 0x48, 0x47, 0x4b, + 0x53, 0x3d, 0x4b, 0x4c, 0x4b, 0x4b, 0x55, 0x56, 0x45, 0x49, 0x46, 0x4c, + 0x45, 0x51, 0x47, 0x50, 0x40, 0x4b, 0x4f, 0x4b, 0x4d, 0x4a, 0x4f, 0x50, + 0x49, 0x53, 0x50, 0x46, 0x40, 0x48, 0x4a, 0x4a, 0x49, 0x4a, 0x42, 0x45, + 0x4b, 0x45, 0x42, 0x45, 0x4e, 0x4e, 0x44, 0x41, 0x4b, 0x4a, 0x49, 0x3f, + 0x41, 0x51, 0x48, 0x4c, 0x40, 0x41, 0x51, 0x42, 0x49, 0x49, 0x48, 0x42, + 0x48, 0x4c, 0x4b, 0x3c, 0x49, 0x45, 0x42, 0x49, 0x4c, 0x46, 0x45, 0x43, + 0x43, 0x48, 0x48, 0x41, 0x43, 0x42, 0x4c, 0x4b, 0x40, 0x45, 0x44, 0x46, + 0x4c, 0x4b, 0x4e, 0x4d, 0x3f, 0x59, 0x55, 0x41, 0x56, 0x5a, 0x51, 0x30, + 0x49, 0x5a, 0x63, 0x4d, 0x61, 0x5b, 0x64, 0x55, 0x34, 0x7a, 0x4c, 0x62, + 0x3e, 0x5d, 0x56, 0x60, 0x48, 0x61, 0x3f, 0x54, 0x46, 0x40, 0x42, 0x56, + 0x52, 0x35, 0x4c, 0x59, 0x45, 0x4c, 0x42, 0x60, 0x49, 0x3f, 0x4c, 0x3c, + 0x52, 0x36, 0x46, 0x3d, 0x58, 0x4b, 0x41, 0x48, 0x3e, 0x45, 0x4e, 0x54, + 0x4c, 0x56, 0x47, 0x44, 0x39, 0x4a, 0x4a, 0x4a, 0x46, 0x48, 0x4a, 0x48, + 0x51, 0x4f, 0x4b, 0x49, 0x45, 0x4b, 0x44, 0x4c, 0x3e, 0x4c, 0x42, 0x59, + 0x47, 0x55, 0x47, 0x47, 0x41, 0x44, 0x44, 0x4a, 0x44, 0x4b, 0x44, 0x46, + 0x49, 0x5a, 0x48, 0x5d, 0x4f, 0x4a, 0x47, 0x50, 0x48, 0x4e, 0x44, 0x57, + 0x49, 0x46, 0x42, 0x4d, 0x3d, 0x4a, 0x4a, 0x58, 0x41, 0x4d, 0x3c, 0x47, + 0x42, 0x4e, 0x4d, 0x49, 0x44, 0x4b, 0x4c, 0x4b, 0x53, 0x42, 0x4a, 0x46, + 0x4e, 0x56, 0x4b, 0x47, 0x50, 0x43, 0x4f, 0x48, 0x49, 0x50, 0x48, 0x50, + 0x42, 0x4c, 0x4e, 0x3c, 0x41, 0x4f, 0x4a, 0x41, 0x44, 0x47, 0x4c, 0x42, + 0x51, 0x4f, 0x53, 0x46, 0x4c, 0x4b, 0x48, 0x51, 0x47, 0x4b, 0x4c, 0x4d, + 0x4d, 0x49, 0x3d, 0x44, 0x4b, 0x42, 0x43, 0x49, 0x51, 0x47, 0x4c, 0x4b, + 0x4a, 0x50, 0x5b, 0x43, 0x5b, 0x68, 0x54, 0x31, 0x4c, 0x5d, 0x5c, 0x54, + 0x63, 0x5a, 0x61, 0x54, 0x3d, 0x7a, 0x51, 0x5b, 0x40, 0x59, 0x5a, 0x62, + 0x4c, 0x5e, 0x42, 0x58, 0x49, 0x3c, 0x38, 0x50, 0x54, 0x37, 0x42, 0x51, + 0x4d, 0x4f, 0x42, 0x68, 0x4a, 0x40, 0x4e, 0x40, 0x3f, 0x3e, 0x3f, 0x40, + 0x54, 0x52, 0x3e, 0x43, 0x46, 0x4a, 0x48, 0x51, 0x4e, 0x4d, 0x42, 0x47, + 0x3f, 0x51, 0x47, 0x44, 0x3f, 0x4c, 0x46, 0x47, 0x4f, 0x55, 0x4b, 0x4e, + 0x4c, 0x51, 0x40, 0x51, 0x47, 0x4a, 0x44, 0x5c, 0x48, 0x54, 0x4b, 0x46, + 0x49, 0x4b, 0x53, 0x59, 0x43, 0x3e, 0x45, 0x4e, 0x4f, 0x58, 0x4b, 0x64, + 0x41, 0x4b, 0x45, 0x4a, 0x4c, 0x51, 0x47, 0x57, 0x45, 0x46, 0x43, 0x4f, + 0x4d, 0x4d, 0x49, 0x58, 0x4b, 0x52, 0x43, 0x4b, 0x45, 0x4c, 0x50, 0x4c, + 0x4e, 0x4b, 0x40, 0x4c, 0x44, 0x4e, 0x4c, 0x47, 0x41, 0x55, 0x45, 0x4a, + 0x4c, 0x48, 0x46, 0x41, 0x47, 0x52, 0x44, 0x4f, 0x48, 0x49, 0x4b, 0x47, + 0x50, 0x4f, 0x42, 0x4a, 0x44, 0x4b, 0x52, 0x43, 0x45, 0x4e, 0x46, 0x49, + 0x45, 0x52, 0x51, 0x45, 0x44, 0x41, 0x4c, 0x46, 0x4c, 0x4b, 0x44, 0x4d, + 0x4f, 0x48, 0x44, 0x4d, 0x56, 0x48, 0x50, 0x4f, 0x3b, 0x4e, 0x55, 0x43, + 0x52, 0x62, 0x57, 0x2c, 0x4d, 0x5e, 0x5e, 0x50, 0x64, 0x5b, 0x6a, 0x55, + 0x39, 0x7d, 0x4b, 0x5e, 0x43, 0x54, 0x5d, 0x5c, 0x4d, 0x5c, 0x42, 0x51, + 0x4c, 0x3d, 0x46, 0x51, 0x4c, 0x2a, 0x3e, 0x54, 0x47, 0x48, 0x46, 0x64, + 0x42, 0x3d, 0x47, 0x3f, 0x42, 0x45, 0x49, 0x3b, 0x59, 0x50, 0x4c, 0x46, + 0x4d, 0x44, 0x47, 0x4d, 0x4a, 0x50, 0x41, 0x48, 0x43, 0x50, 0x3e, 0x44, + 0x4b, 0x53, 0x48, 0x49, 0x51, 0x51, 0x4d, 0x57, 0x49, 0x4f, 0x53, 0x50, + 0x46, 0x4f, 0x41, 0x5d, 0x47, 0x46, 0x49, 0x51, 0x45, 0x41, 0x4a, 0x56, + 0x4f, 0x4e, 0x4d, 0x4a, 0x3e, 0x55, 0x47, 0x65, 0x48, 0x51, 0x4d, 0x4e, + 0x46, 0x43, 0x48, 0x5b, 0x48, 0x4f, 0x4f, 0x48, 0x4b, 0x4d, 0x4e, 0x5c, + 0x4f, 0x4c, 0x54, 0x48, 0x4a, 0x4d, 0x4e, 0x4e, 0x44, 0x48, 0x43, 0x52, + 0x41, 0x52, 0x48, 0x4f, 0x46, 0x4f, 0x51, 0x41, 0x44, 0x45, 0x41, 0x4b, + 0x43, 0x4e, 0x4e, 0x42, 0x48, 0x41, 0x45, 0x43, 0x44, 0x43, 0x4c, 0x4c, + 0x51, 0x54, 0x4c, 0x32, 0x46, 0x52, 0x4e, 0x49, 0x40, 0x4d, 0x43, 0x4f, + 0x4a, 0x4d, 0x4d, 0x49, 0x46, 0x4c, 0x41, 0x4d, 0x41, 0x3a, 0x50, 0x4c, + 0x5a, 0x4e, 0x49, 0x53, 0x4d, 0x53, 0x53, 0x3d, 0x52, 0x64, 0x55, 0x2a, + 0x47, 0x5d, 0x61, 0x51, 0x5b, 0x5d, 0x66, 0x52, 0x3f, 0xfd, 0x55, 0x5a, + 0x4b, 0x54, 0x5b, 0x60, 0x49, 0x5d, 0x43, 0x57, 0x47, 0x41, 0x45, 0x5e, + 0x4c, 0x28, 0x3e, 0x40, 0x49, 0x4e, 0x40, 0x69, 0x4a, 0x44, 0x45, 0x43, + 0x45, 0x3d, 0x39, 0x40, 0x4c, 0x53, 0x4b, 0x3d, 0x4e, 0x43, 0x48, 0x55, + 0x4d, 0x50, 0x4d, 0x49, 0x4f, 0x48, 0x3e, 0x46, 0x47, 0x56, 0x40, 0x48, + 0x46, 0x53, 0x50, 0x5d, 0x43, 0x54, 0x49, 0x47, 0x49, 0x4c, 0x48, 0x5d, + 0x49, 0x51, 0x50, 0x3d, 0x41, 0x47, 0x48, 0x64, 0x4b, 0x44, 0x49, 0x41, + 0x54, 0x48, 0x3d, 0x6b, 0x4c, 0x5a, 0x48, 0x4e, 0x40, 0x4c, 0x52, 0x5f, + 0x54, 0x4a, 0x3f, 0x48, 0x43, 0x43, 0x44, 0x66, 0x49, 0x47, 0x43, 0x46, + 0x47, 0x54, 0x42, 0x54, 0x4b, 0x4e, 0x49, 0x49, 0x49, 0x4b, 0x52, 0x4f, + 0x43, 0x46, 0x4b, 0x49, 0x54, 0x4b, 0x40, 0x48, 0x47, 0x4a, 0x46, 0x47, + 0x44, 0x47, 0x4c, 0x37, 0x3f, 0x49, 0x45, 0x44, 0x50, 0x49, 0x44, 0x36, + 0x4d, 0x40, 0x45, 0x49, 0x53, 0x55, 0x44, 0x42, 0x47, 0x48, 0x46, 0x40, + 0x4f, 0x4c, 0x41, 0x42, 0x52, 0x3a, 0x43, 0x46, 0x55, 0x51, 0x4e, 0x4f, + 0x48, 0x51, 0x55, 0x48, 0x52, 0x66, 0x4e, 0x33, 0x49, 0x5b, 0x5f, 0x4b, + 0x5f, 0x5b, 0x66, 0x52, 0x41, 0x7c, 0x4a, 0x59, 0x47, 0x59, 0x58, 0x67, + 0x49, 0x5e, 0x44, 0x57, 0x49, 0x4c, 0x43, 0x56, 0x41, 0x27, 0x4c, 0x44, + 0x51, 0x44, 0x42, 0x65, 0x49, 0x44, 0x40, 0x3d, 0x4d, 0x3e, 0x4c, 0x3c, + 0x4f, 0x4b, 0x45, 0x44, 0x4d, 0x48, 0x47, 0x54, 0x4d, 0x4e, 0x44, 0x42, + 0x47, 0x44, 0x3d, 0x49, 0x4e, 0x50, 0x49, 0x45, 0x58, 0x4a, 0x54, 0x5c, + 0x41, 0x49, 0x4f, 0x42, 0x44, 0x4f, 0x4a, 0x62, 0x48, 0x50, 0x48, 0x43, + 0x51, 0x53, 0x47, 0x6c, 0x40, 0x46, 0x3d, 0x46, 0x4a, 0x50, 0x43, 0x69, + 0x49, 0x4f, 0x4a, 0x4c, 0x49, 0x46, 0x43, 0x6a, 0x48, 0x50, 0x49, 0x48, + 0x48, 0x51, 0x4b, 0x65, 0x42, 0x4b, 0x4d, 0x48, 0x44, 0x4e, 0x49, 0x60, + 0x44, 0x52, 0x42, 0x42, 0x47, 0x48, 0x4b, 0x51, 0x50, 0x4b, 0x3c, 0x4d, + 0x4c, 0x44, 0x48, 0x55, 0x51, 0x4c, 0x55, 0x4e, 0x52, 0x4c, 0x4b, 0x39, + 0x48, 0x42, 0x49, 0x49, 0x49, 0x50, 0x49, 0x32, 0x4e, 0x4b, 0x45, 0x4f, + 0x42, 0x4b, 0x47, 0x50, 0x48, 0x45, 0x54, 0x49, 0x4c, 0x46, 0x40, 0x46, + 0x43, 0x3d, 0x51, 0x44, 0x53, 0x4f, 0x54, 0x55, 0x43, 0x4f, 0x5b, 0x47, + 0x53, 0x6c, 0x57, 0x2e, 0x50, 0x55, 0x5a, 0x4d, 0x57, 0x5d, 0x70, 0x50, + 0x3f, 0x79, 0x4a, 0x5a, 0x4c, 0x58, 0x59, 0x63, 0x45, 0x69, 0x48, 0x58, + 0x42, 0x4b, 0x43, 0x5c, 0x46, 0x28, 0x48, 0x49, 0x4c, 0x3f, 0x45, 0x58, + 0x45, 0x44, 0x47, 0x40, 0x4c, 0x42, 0x3e, 0x37, 0x45, 0x54, 0x48, 0x3b, + 0x4e, 0x48, 0x43, 0x4a, 0x50, 0x4a, 0x49, 0x46, 0x4c, 0x54, 0x3f, 0x4b, + 0x4e, 0x56, 0x48, 0x49, 0x49, 0x4c, 0x51, 0x5f, 0x4d, 0x4b, 0x43, 0x4d, + 0x47, 0x51, 0x43, 0x59, 0x45, 0x4e, 0x4f, 0x45, 0x44, 0x54, 0x44, 0x6d, + 0x47, 0x51, 0x43, 0x4e, 0x4c, 0x4f, 0x43, 0x6d, 0x48, 0x53, 0x4b, 0x47, + 0x49, 0x48, 0x46, 0x6a, 0x51, 0x4c, 0x4d, 0x45, 0x4e, 0x47, 0x46, 0x62, + 0x4a, 0x54, 0x51, 0x4c, 0x47, 0x4d, 0x4a, 0x61, 0x3d, 0x50, 0x4c, 0x4c, + 0x45, 0x3f, 0x3e, 0x54, 0x3d, 0x53, 0x48, 0x47, 0x52, 0x4b, 0x47, 0x51, + 0x4f, 0x45, 0x4b, 0x4a, 0x4c, 0x46, 0x44, 0x37, 0x42, 0x50, 0x49, 0x4f, + 0x51, 0x41, 0x44, 0x38, 0x54, 0x40, 0x51, 0x52, 0x3e, 0x43, 0x44, 0x47, + 0x49, 0x4b, 0x4b, 0x46, 0x53, 0x54, 0x55, 0x4b, 0x4a, 0x37, 0x43, 0x4a, + 0x51, 0x47, 0x51, 0x54, 0x43, 0x46, 0x56, 0x3d, 0x54, 0x66, 0x4f, 0x30, + 0x45, 0x52, 0x5a, 0x43, 0x5c, 0x65, 0x5d, 0x52, 0x32, 0x77, 0x53, 0x5f, + 0x4a, 0x5a, 0x4f, 0x5e, 0x4e, 0x61, 0x4b, 0x5b, 0x4a, 0x53, 0x3e, 0x61, + 0x47, 0x24, 0x3e, 0x48, 0x4d, 0x43, 0x40, 0x53, 0x4e, 0x41, 0x43, 0x3d, + 0x50, 0x49, 0x41, 0x3a, 0x4e, 0x4b, 0x48, 0x49, 0x48, 0x49, 0x46, 0x50, + 0x4f, 0x4b, 0x47, 0x4b, 0x48, 0x52, 0x3e, 0x4d, 0x4d, 0x59, 0x4c, 0x3e, + 0x52, 0x49, 0x4f, 0x5e, 0x54, 0x59, 0x47, 0x4d, 0x40, 0x4c, 0x4b, 0x64, + 0x42, 0x4c, 0x53, 0x46, 0x4e, 0x50, 0x46, 0x6a, 0x41, 0x59, 0x44, 0x4b, + 0x4f, 0x44, 0x52, 0x6c, 0x54, 0x4e, 0x46, 0x48, 0x42, 0x3d, 0x44, 0x67, + 0x44, 0x4f, 0x47, 0x54, 0x4c, 0x4f, 0x43, 0x61, 0x4c, 0x54, 0x4f, 0x43, + 0x49, 0x40, 0x4a, 0x5f, 0x4a, 0x52, 0x47, 0x43, 0x4c, 0x43, 0x49, 0x53, + 0x4c, 0x4b, 0x43, 0x3d, 0x4e, 0x45, 0x49, 0x50, 0x44, 0x53, 0x4f, 0x48, + 0x4b, 0x46, 0x44, 0x3c, 0x50, 0x42, 0x43, 0x40, 0x47, 0x43, 0x42, 0x34, + 0x47, 0x42, 0x3f, 0x4a, 0x48, 0x42, 0x48, 0x4c, 0x42, 0x4c, 0x4e, 0x47, + 0x48, 0x47, 0x51, 0x51, 0x4d, 0x3d, 0x3e, 0x4b, 0x54, 0x4c, 0x4c, 0x59, + 0x4f, 0x50, 0x57, 0x3c, 0x54, 0x62, 0x54, 0x35, 0x3d, 0x5a, 0x5b, 0x47, + 0x59, 0x63, 0x66, 0x4d, 0x3c, 0x79, 0x50, 0x5f, 0x45, 0x58, 0x4e, 0x5d, + 0x48, 0x61, 0x43, 0x54, 0x47, 0x54, 0x4d, 0x54, 0x4b, 0x25, 0x41, 0x44, + 0x4c, 0x4a, 0x3b, 0x52, 0x47, 0x3c, 0x45, 0x3c, 0x53, 0x44, 0x44, 0x40, + 0x50, 0x4c, 0x45, 0x3a, 0x4c, 0x51, 0x44, 0x49, 0x4d, 0x52, 0x4d, 0x4b, + 0x45, 0x52, 0x3d, 0x50, 0x4a, 0x58, 0x4a, 0x47, 0x4d, 0x47, 0x4e, 0x52, + 0x4f, 0x4d, 0x4f, 0x49, 0x52, 0x52, 0x4c, 0x5e, 0x47, 0x4d, 0x46, 0x4d, + 0x4c, 0x48, 0x50, 0x70, 0x41, 0x4a, 0x48, 0x3d, 0x45, 0x48, 0x45, 0x74, + 0x47, 0x4c, 0x43, 0x4f, 0x4a, 0x4a, 0x40, 0x68, 0x52, 0x49, 0x3e, 0x3e, + 0x4e, 0x4b, 0x4b, 0x69, 0x42, 0x4f, 0x45, 0x47, 0x3f, 0x45, 0x46, 0x56, + 0x45, 0x4a, 0x47, 0x44, 0x52, 0x4b, 0x53, 0x4e, 0x4e, 0x46, 0x45, 0x40, + 0x47, 0x4b, 0x53, 0x52, 0x53, 0x51, 0x4f, 0x46, 0x42, 0x43, 0x50, 0x3e, + 0x48, 0x4e, 0x41, 0x53, 0x4d, 0x48, 0x48, 0x33, 0x40, 0x43, 0x4b, 0x42, + 0x52, 0x4c, 0x42, 0x4e, 0x41, 0x4e, 0x4f, 0x50, 0x43, 0x49, 0x4d, 0x47, + 0x4a, 0x3a, 0x3f, 0x51, 0x51, 0x44, 0x4e, 0x54, 0x40, 0x55, 0x59, 0x3c, + 0x57, 0x67, 0x4e, 0x2e, 0x4c, 0x5b, 0x5b, 0x51, 0x58, 0x63, 0x62, 0x52, + 0x3c, 0x72, 0x51, 0x5a, 0x4e, 0x53, 0x4a, 0x5c, 0x51, 0x69, 0x42, 0x51, + 0x48, 0x54, 0x48, 0x57, 0x3e, 0x37, 0x3f, 0x4d, 0x4d, 0x4a, 0x35, 0x57, + 0x4e, 0x40, 0x45, 0x4a, 0x45, 0x4e, 0x49, 0x40, 0x49, 0x53, 0x51, 0x44, + 0x4a, 0x50, 0x4b, 0x4b, 0x50, 0x4f, 0x3e, 0x44, 0x45, 0x44, 0x4c, 0x51, + 0x47, 0x51, 0x46, 0x42, 0x48, 0x50, 0x49, 0x4d, 0x43, 0x54, 0x52, 0x4d, + 0x4e, 0x4f, 0x3f, 0x63, 0x54, 0x57, 0x41, 0x44, 0x4e, 0x50, 0x4e, 0x66, + 0x41, 0x53, 0x4b, 0x4d, 0x4e, 0x4f, 0x43, 0x6d, 0x4e, 0x51, 0x49, 0x4f, + 0x49, 0x4a, 0x4a, 0x6c, 0x4b, 0x4f, 0x3d, 0x47, 0x4d, 0x51, 0x3c, 0x66, + 0x4b, 0x56, 0x3e, 0x4c, 0x41, 0x46, 0x45, 0x68, 0x47, 0x4b, 0x4a, 0x54, + 0x53, 0x48, 0x51, 0x59, 0x45, 0x43, 0x50, 0x45, 0x4f, 0x45, 0x42, 0x55, + 0x48, 0x52, 0x4c, 0x46, 0x52, 0x49, 0x47, 0x3d, 0x55, 0x48, 0x52, 0x52, + 0x40, 0x4e, 0x47, 0x31, 0x45, 0x4f, 0x42, 0x4a, 0x4e, 0x50, 0x42, 0x4a, + 0x49, 0x57, 0x46, 0x4b, 0x45, 0x4e, 0x4d, 0x46, 0x47, 0x43, 0x50, 0x4e, + 0x4f, 0x4c, 0x53, 0x55, 0x45, 0x51, 0x5b, 0x3a, 0x52, 0x64, 0x54, 0x2d, + 0x42, 0x59, 0x59, 0x45, 0x59, 0x67, 0x69, 0x53, 0x3f, 0x78, 0x50, 0x60, + 0x4c, 0x4c, 0x5b, 0x53, 0x45, 0x63, 0x49, 0x63, 0x51, 0x4c, 0x41, 0x4e, + 0x4b, 0x37, 0x45, 0x4e, 0x48, 0x4c, 0x39, 0x55, 0x44, 0x37, 0x3c, 0x49, + 0x44, 0x56, 0x3e, 0x40, 0x4d, 0x45, 0x4c, 0x43, 0x42, 0x41, 0x40, 0x42, + 0x57, 0x4f, 0x43, 0x3f, 0x52, 0x53, 0x51, 0x4b, 0x4b, 0x55, 0x46, 0x40, + 0x49, 0x45, 0x40, 0x4f, 0x47, 0x58, 0x4b, 0x53, 0x4e, 0x52, 0x54, 0x5e, + 0x4b, 0x51, 0x50, 0x44, 0x50, 0x4b, 0x4f, 0x70, 0x49, 0x4f, 0x4c, 0x50, + 0x45, 0x56, 0x4b, 0x6b, 0x49, 0x52, 0x4a, 0x3f, 0x44, 0x4b, 0x48, 0x72, + 0x4c, 0x47, 0x4e, 0x43, 0x46, 0x4c, 0x4f, 0x61, 0x4a, 0x52, 0x52, 0x46, + 0x4a, 0x4d, 0x46, 0x65, 0x48, 0x4e, 0x4d, 0x4e, 0x46, 0x4e, 0x53, 0x59, + 0x43, 0x49, 0x43, 0x47, 0x45, 0x47, 0x53, 0x50, 0x3e, 0x4d, 0x41, 0x46, + 0x4c, 0x4a, 0x4c, 0x35, 0x3f, 0x4f, 0x50, 0x48, 0x47, 0x4d, 0x4c, 0x32, + 0x45, 0x53, 0x43, 0x4d, 0x4e, 0x4a, 0x3e, 0x4b, 0x55, 0x4f, 0x53, 0x4c, + 0x4a, 0x4d, 0x48, 0x53, 0x4f, 0x3a, 0x47, 0x4b, 0x4e, 0x4e, 0x51, 0x59, + 0x41, 0x50, 0x57, 0x38, 0x5d, 0x63, 0x59, 0x2b, 0x45, 0x53, 0x5a, 0x4e, + 0x5c, 0x60, 0x5e, 0x4c, 0x41, 0x6f, 0x53, 0x5c, 0x48, 0x53, 0x56, 0x54, + 0x4b, 0x62, 0x46, 0x63, 0x47, 0x4e, 0x40, 0x51, 0x43, 0x36, 0x44, 0x42, + 0x46, 0x51, 0x41, 0x54, 0x4e, 0x36, 0x40, 0x4b, 0x55, 0x49, 0x40, 0x3f, + 0x4b, 0x42, 0x4a, 0x4a, 0x48, 0x47, 0x40, 0x43, 0x4d, 0x4f, 0x55, 0x3f, + 0x53, 0x42, 0x4d, 0x56, 0x49, 0x51, 0x4f, 0x41, 0x3b, 0x48, 0x43, 0x4e, + 0x4b, 0x5c, 0x4f, 0x45, 0x4a, 0x4c, 0x46, 0x66, 0x43, 0x45, 0x46, 0x48, + 0x4f, 0x4e, 0x40, 0x71, 0x4b, 0x4e, 0x3e, 0x42, 0x4d, 0x52, 0x42, 0x71, + 0x4c, 0x54, 0x4f, 0x3f, 0x4c, 0x43, 0x4a, 0x73, 0x48, 0x48, 0x4c, 0x4b, + 0x4c, 0x4d, 0x40, 0x72, 0x3e, 0x51, 0x49, 0x48, 0x52, 0x53, 0x45, 0x65, + 0x52, 0x4e, 0x4f, 0x44, 0x4c, 0x43, 0x4a, 0x5e, 0x3e, 0x56, 0x46, 0x55, + 0x55, 0x43, 0x49, 0x51, 0x4f, 0x52, 0x49, 0x4d, 0x46, 0x47, 0x49, 0x3e, + 0x51, 0x49, 0x41, 0x53, 0x42, 0x47, 0x46, 0x3b, 0x4d, 0x4e, 0x48, 0x44, + 0x42, 0x48, 0x4c, 0x47, 0x42, 0x4e, 0x4a, 0x3e, 0x44, 0x54, 0x4a, 0x4d, + 0x49, 0x41, 0x41, 0x53, 0x52, 0x4c, 0x4c, 0x56, 0x49, 0x4a, 0x5a, 0x3f, + 0x5b, 0x5c, 0x59, 0x2f, 0x49, 0x52, 0x5a, 0x4e, 0x5a, 0x61, 0x67, 0x4c, + 0x41, 0x6f, 0x5a, 0x5a, 0x40, 0x5a, 0x54, 0x4e, 0x49, 0x66, 0x45, 0x5a, + 0x4a, 0x45, 0x44, 0x4b, 0x44, 0x36, 0x41, 0x4c, 0x45, 0x44, 0x3d, 0x51, + 0x3f, 0x35, 0x3c, 0x46, 0x53, 0x5c, 0x3f, 0x3e, 0x50, 0x43, 0x46, 0x4b, + 0x40, 0x54, 0x41, 0x47, 0x4b, 0x51, 0x41, 0x46, 0x4a, 0x4d, 0x51, 0x52, + 0x43, 0x58, 0x45, 0x46, 0x4e, 0x46, 0x4a, 0x4b, 0x44, 0x54, 0x4c, 0x4c, + 0x43, 0x59, 0x48, 0x61, 0x4e, 0x4f, 0x4d, 0x4d, 0x4a, 0x52, 0x4c, 0x6e, + 0x49, 0x57, 0x48, 0x4d, 0x46, 0x46, 0x4d, 0x72, 0x4a, 0x4e, 0x47, 0x44, + 0x49, 0x4f, 0x48, 0x73, 0x42, 0x40, 0x4d, 0x44, 0x4d, 0x57, 0x3e, 0x69, + 0x50, 0x52, 0x4c, 0x55, 0x46, 0x4c, 0x44, 0x5f, 0x4b, 0x4d, 0x55, 0x4c, + 0x48, 0x49, 0x4a, 0x5e, 0x47, 0x4b, 0x45, 0x53, 0x55, 0x53, 0x4d, 0x53, + 0x47, 0x5c, 0x45, 0x4e, 0x4e, 0x52, 0x4c, 0x39, 0x4b, 0x4c, 0x49, 0x46, + 0x4a, 0x4e, 0x4b, 0x33, 0x46, 0x47, 0x52, 0x41, 0x49, 0x4b, 0x4c, 0x48, + 0x51, 0x53, 0x44, 0x4c, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x4b, 0x50, 0x47, + 0x4d, 0x4b, 0x4c, 0x4f, 0x44, 0x45, 0x58, 0x3c, 0x56, 0x5a, 0x56, 0x23, + 0x4f, 0x4d, 0x5c, 0x4e, 0x59, 0x5a, 0x65, 0x43, 0x45, 0x66, 0x54, 0x5f, + 0x45, 0x5e, 0x54, 0x4f, 0x48, 0x5f, 0x44, 0x59, 0x48, 0x46, 0x47, 0x49, + 0x4d, 0x3c, 0x49, 0x54, 0x3e, 0x48, 0x43, 0x5b, 0x4a, 0x35, 0x41, 0x43, + 0x4b, 0x55, 0x43, 0x38, 0x46, 0x42, 0x4a, 0x4e, 0x54, 0x4b, 0x4d, 0x46, + 0x43, 0x4e, 0x44, 0x47, 0x56, 0x4c, 0x51, 0x57, 0x41, 0x4d, 0x43, 0x41, + 0x51, 0x47, 0x41, 0x51, 0x51, 0x4f, 0x46, 0x50, 0x52, 0x4e, 0x4d, 0x60, + 0x41, 0x49, 0x46, 0x50, 0x48, 0x56, 0x42, 0x6d, 0x40, 0x45, 0x44, 0x55, + 0x40, 0x4e, 0x40, 0x7c, 0x47, 0x5a, 0x44, 0x44, 0x45, 0x56, 0x55, 0x71, + 0x47, 0x4b, 0x4b, 0x45, 0x4f, 0x54, 0x4c, 0x73, 0x48, 0x55, 0x44, 0x4d, + 0x4a, 0x47, 0x49, 0x5e, 0x4d, 0x52, 0x4e, 0x4c, 0x48, 0x52, 0x48, 0x58, + 0x4c, 0x5a, 0x49, 0x4b, 0x53, 0x46, 0x4d, 0x4b, 0x48, 0x53, 0x41, 0x49, + 0x4a, 0x56, 0x51, 0x3a, 0x4c, 0x4e, 0x4f, 0x51, 0x4c, 0x59, 0x47, 0x45, + 0x4f, 0x50, 0x4a, 0x4f, 0x4d, 0x3f, 0x44, 0x4e, 0x42, 0x4a, 0x4a, 0x43, + 0x46, 0x4e, 0x4c, 0x4f, 0x47, 0x47, 0x4c, 0x4b, 0x52, 0x50, 0x50, 0x4b, + 0x42, 0x45, 0x54, 0x44, 0x54, 0x59, 0x4c, 0x2b, 0x4d, 0x4c, 0x55, 0x4e, + 0x5c, 0x5b, 0x5a, 0x42, 0x47, 0x5e, 0x56, 0x59, 0x47, 0x65, 0x55, 0x4c, + 0x4c, 0x59, 0x42, 0x5a, 0x4e, 0x46, 0x4e, 0x4b, 0x53, 0x46, 0x49, 0x56, + 0x48, 0x58, 0x4b, 0x4f, 0x45, 0x38, 0x40, 0x44, 0x49, 0x51, 0x4a, 0x3b, + 0x53, 0x40, 0x40, 0x48, 0x51, 0x49, 0x44, 0x46, 0x52, 0x4b, 0x4e, 0x45, + 0x48, 0x5a, 0x4e, 0x57, 0x44, 0x53, 0x49, 0x40, 0x4c, 0x47, 0x41, 0x4f, + 0x49, 0x55, 0x46, 0x50, 0x57, 0x5b, 0x48, 0x66, 0x50, 0x49, 0x51, 0x55, + 0x55, 0x4f, 0x47, 0x72, 0x49, 0x4f, 0x41, 0x4c, 0x49, 0x42, 0x48, 0x75, + 0x4a, 0x55, 0x45, 0x4a, 0x41, 0x51, 0x41, 0x70, 0x47, 0x49, 0x42, 0x52, + 0x4f, 0x47, 0x46, 0x63, 0x4f, 0x53, 0x46, 0x4f, 0x49, 0x53, 0x52, 0x63, + 0x4c, 0x59, 0x46, 0x41, 0x49, 0x51, 0x3e, 0x53, 0x45, 0x52, 0x51, 0x40, + 0x4f, 0x4c, 0x41, 0x4c, 0x47, 0x4a, 0x46, 0x47, 0x53, 0x47, 0x48, 0x39, + 0x53, 0x4b, 0x46, 0x4b, 0x50, 0x4c, 0x41, 0x40, 0x48, 0x4e, 0x49, 0x4e, + 0x44, 0x53, 0x44, 0x4e, 0x53, 0x49, 0x49, 0x4e, 0x46, 0x3f, 0x45, 0x42, + 0x4c, 0x47, 0x42, 0x4e, 0x49, 0x4a, 0x49, 0x44, 0x51, 0x48, 0x57, 0x4c, + 0x4d, 0x60, 0x4e, 0x2d, 0x46, 0x4d, 0x58, 0x53, 0x5c, 0x56, 0x5e, 0x41, + 0x3e, 0x66, 0x53, 0x5b, 0x49, 0x59, 0x5a, 0x55, 0x4e, 0x59, 0x46, 0x4a, + 0x44, 0x42, 0x45, 0x3d, 0x4d, 0x45, 0x44, 0x4f, 0x4d, 0x53, 0x42, 0x5a, + 0x43, 0x3c, 0x48, 0x4f, 0x44, 0x59, 0x3f, 0x33, 0x45, 0x48, 0x43, 0x45, + 0x4d, 0x56, 0x48, 0x44, 0x3e, 0x48, 0x46, 0x4d, 0x44, 0x53, 0x46, 0x4e, + 0x45, 0x52, 0x40, 0x46, 0x4c, 0x50, 0x4e, 0x4b, 0x4d, 0x46, 0x48, 0x46, + 0x50, 0x52, 0x4e, 0x57, 0x3f, 0x4a, 0x49, 0x50, 0x53, 0x4e, 0x41, 0x66, + 0x49, 0x4f, 0x40, 0x4b, 0x50, 0x4c, 0x4a, 0x70, 0x42, 0x51, 0x41, 0x4c, + 0x50, 0x4f, 0x46, 0x60, 0x45, 0x47, 0x54, 0x4c, 0x49, 0x59, 0x52, 0x61, + 0x4a, 0x53, 0x52, 0x4f, 0x4b, 0x4c, 0x46, 0x56, 0x4b, 0x54, 0x4f, 0x47, + 0x53, 0x49, 0x4f, 0x50, 0x4a, 0x54, 0x45, 0x4e, 0x47, 0x48, 0x47, 0x42, + 0x49, 0x44, 0x46, 0x46, 0x55, 0x4c, 0x4f, 0x36, 0x4c, 0x49, 0x3f, 0x4e, + 0x45, 0x4b, 0x4b, 0x36, 0x48, 0x4f, 0x4b, 0x50, 0x45, 0x47, 0x49, 0x3f, + 0x50, 0x4b, 0x52, 0x48, 0x4c, 0x41, 0x49, 0x43, 0x4e, 0x3c, 0x43, 0x45, + 0x3e, 0x45, 0x48, 0x44, 0x4d, 0x48, 0x56, 0x47, 0x4b, 0x54, 0x52, 0x2b, + 0x4d, 0x4e, 0x57, 0x4f, 0x57, 0x4f, 0x56, 0x43, 0x48, 0x5f, 0x4c, 0x51, + 0x4d, 0x58, 0x4f, 0x4e, 0x50, 0x50, 0x48, 0x4a, 0x4d, 0x3f, 0x47, 0x40, + 0x4b, 0x4a, 0x4e, 0x4b, 0x4a, 0x58, 0x42, 0x49, 0x3f, 0x42, 0x3d, 0x4d, + 0x46, 0x53, 0x45, 0x3e, 0x4e, 0x49, 0x4f, 0x4a, 0x47, 0x46, 0x40, 0x3e, + 0x4c, 0x4d, 0x4d, 0x45, 0x4a, 0x56, 0x40, 0x4a, 0x47, 0x57, 0x4f, 0x48, + 0x4f, 0x48, 0x47, 0x49, 0x4e, 0x52, 0x50, 0x48, 0x42, 0x52, 0x43, 0x5a, + 0x49, 0x42, 0x4f, 0x4f, 0x51, 0x51, 0x50, 0x5c, 0x4b, 0x43, 0x4b, 0x48, + 0x50, 0x51, 0x4b, 0x6d, 0x53, 0x4e, 0x44, 0x4c, 0x4c, 0x51, 0x46, 0x5b, + 0x44, 0x48, 0x4d, 0x4c, 0x46, 0x4f, 0x54, 0x54, 0x4e, 0x54, 0x42, 0x4e, + 0x4c, 0x49, 0x49, 0x58, 0x49, 0x53, 0x53, 0x4a, 0x4e, 0x4b, 0x47, 0x53, + 0x43, 0x55, 0x46, 0x51, 0x3d, 0x3d, 0x4c, 0x47, 0x4e, 0x51, 0x47, 0x48, + 0x4b, 0x4c, 0x42, 0x3b, 0x43, 0x4f, 0x44, 0x4d, 0x54, 0x4b, 0x4a, 0x47, + 0x4c, 0x42, 0x4b, 0x43, 0x41, 0x4e, 0x4d, 0x50, 0x45, 0x46, 0x41, 0x4a, + 0x49, 0x49, 0x54, 0x47, 0x4c, 0x4b, 0x50, 0x4e, 0x3f, 0x43, 0x40, 0x41, + 0x44, 0x54, 0x51, 0x47, 0x4c, 0x4b, 0x4f, 0x34, 0x4d, 0x4c, 0x4f, 0x49, + 0x56, 0x4e, 0x4b, 0x3e, 0x48, 0x53, 0x4e, 0x56, 0x49, 0x4e, 0x4c, 0x40, + 0x55, 0x4a, 0x46, 0x4f, 0x48, 0x4a, 0x55, 0x41, 0x55, 0x3d, 0x47, 0x51, + 0x50, 0x51, 0x45, 0x51, 0x4b, 0x4e, 0x4a, 0x4f, 0x4b, 0x45, 0x42, 0x3c, + 0x4e, 0x46, 0x47, 0x49, 0x4a, 0x4c, 0x48, 0x41, 0x4f, 0x4a, 0x44, 0x45, + 0x4e, 0x4e, 0x43, 0x41, 0x4c, 0x47, 0x48, 0x49, 0x4c, 0x48, 0x4f, 0x4a, + 0x4f, 0x4a, 0x4b, 0x45, 0x42, 0x40, 0x52, 0x55, 0x4f, 0x49, 0x44, 0x54, + 0x49, 0x48, 0x51, 0x4d, 0x44, 0x4a, 0x4d, 0x49, 0x4e, 0x4e, 0x51, 0x5d, + 0x42, 0x4d, 0x49, 0x3f, 0x48, 0x58, 0x40, 0x5e, 0x48, 0x4f, 0x49, 0x53, + 0x45, 0x47, 0x4f, 0x53, 0x4d, 0x4f, 0x4d, 0x4d, 0x46, 0x55, 0x43, 0x51, + 0x4f, 0x51, 0x4a, 0x4e, 0x49, 0x42, 0x49, 0x50, 0x47, 0x4d, 0x42, 0x47, + 0x46, 0x50, 0x55, 0x47, 0x4d, 0x47, 0x3e, 0x51, 0x4d, 0x43, 0x44, 0x39, + 0x4e, 0x4b, 0x41, 0x48, 0x52, 0x53, 0x4d, 0x39, 0x4d, 0x51, 0x4c, 0x46, + 0x4e, 0x47, 0x49, 0x41, 0x45, 0x4a, 0x4a, 0x45, 0x50, 0x4a, 0x40, 0x48, + 0x43, 0x47, 0x44, 0x50, 0x4d, 0x47, 0x4a, 0x47, 0x45, 0x57, 0x41, 0x34, + 0x51, 0x40, 0x45, 0x44, 0x3c, 0x47, 0x46, 0x47, 0x44, 0x48, 0x42, 0x40, + 0x37, 0x53, 0x4a, 0x43, 0x49, 0x4b, 0x43, 0x44, 0x4f, 0x4f, 0x48, 0x48, + 0x53, 0x49, 0x4b, 0x48, 0x4e, 0x4c, 0x42, 0x45, 0x4c, 0x4a, 0x4a, 0x46, + 0x47, 0x57, 0x3e, 0x46, 0x46, 0x45, 0x4a, 0x43, 0x46, 0x49, 0x43, 0x52, + 0x3e, 0x48, 0x4a, 0x4b, 0x47, 0x47, 0x48, 0x4a, 0x4b, 0x4b, 0x4e, 0x44, + 0x42, 0x44, 0x50, 0x41, 0x49, 0x49, 0x4d, 0x4b, 0x44, 0x46, 0x4a, 0x52, + 0x4d, 0x47, 0x49, 0x4b, 0x4d, 0x49, 0x41, 0x48, 0x4b, 0x3f, 0x45, 0x4f, + 0x51, 0x41, 0x55, 0x42, 0x49, 0x4b, 0x4b, 0x51, 0x4f, 0x4f, 0x42, 0x4e, + 0x4e, 0x4a, 0x52, 0x41, 0x4f, 0x42, 0x48, 0x3d, 0x4a, 0x44, 0x50, 0x4b, + 0x49, 0x45, 0x51, 0x46, 0x51, 0x44, 0x4d, 0x47, 0x4a, 0x4a, 0x4d, 0x49, + 0x4d, 0x48, 0x4d, 0x4f, 0x4d, 0x44, 0x48, 0x4e, 0x4a, 0x4b, 0x40, 0x4f, + 0x47, 0x3a, 0x41, 0x47, 0x4a, 0x4a, 0x4a, 0x48, 0x42, 0x41, 0x4d, 0x56, + 0x3f, 0x52, 0x4d, 0x4c, 0x44, 0x48, 0x47, 0x4e, 0x51, 0x4c, 0x49, 0x47, + 0x44, 0x4c, 0x4b, 0x47, 0x48, 0x46, 0x47, 0x4f, 0x43, 0x41, 0x3e, 0x47, + 0x53, 0x4a, 0x46, 0x42, 0x46, 0x61, 0x43, 0x30, 0x4e, 0x52, 0x43, 0x45, + 0x32, 0x4a, 0x45, 0x48, 0x51, 0x3e, 0x44, 0x3b, 0x3a, 0x63, 0x4c, 0x46, + 0x4c, 0x49, 0x3d, 0x41, 0x52, 0x53, 0x43, 0x43, 0x45, 0x3d, 0x48, 0x40, + 0x4b, 0x4a, 0x49, 0x48, 0x4d, 0x49, 0x4b, 0x4c, 0x3f, 0x4e, 0x4b, 0x47, + 0x45, 0x4d, 0x3f, 0x4d, 0x43, 0x50, 0x48, 0x4b, 0x54, 0x3e, 0x44, 0x4e, + 0x3e, 0x4c, 0x43, 0x4b, 0x4c, 0x4b, 0x3e, 0x49, 0x50, 0x52, 0x4a, 0x4a, + 0x50, 0x50, 0x43, 0x4e, 0x49, 0x48, 0x51, 0x50, 0x47, 0x3d, 0x45, 0x4b, + 0x47, 0x46, 0x4d, 0x4c, 0x45, 0x4d, 0x4a, 0x4d, 0x42, 0x4d, 0x47, 0x4f, + 0x40, 0x43, 0x46, 0x51, 0x47, 0x4b, 0x43, 0x49, 0x49, 0x50, 0x4b, 0x4b, + 0x46, 0x4a, 0x4c, 0x48, 0x49, 0x47, 0x4b, 0x56, 0x55, 0x4f, 0x49, 0x4f, + 0x4f, 0x4e, 0x4b, 0x49, 0x4a, 0x4a, 0x49, 0x47, 0x44, 0x4b, 0x47, 0x50, + 0x46, 0x4c, 0x46, 0x4c, 0x4b, 0x4e, 0x49, 0x57, 0x4d, 0x3e, 0x46, 0x47, + 0x50, 0x45, 0x4f, 0x52, 0x3e, 0x4d, 0x49, 0x4a, 0x40, 0x49, 0x4f, 0x5c, + 0x3e, 0x4a, 0x47, 0x45, 0x47, 0x41, 0x44, 0x3f, 0x4b, 0x4a, 0x52, 0x43, + 0x41, 0x43, 0x43, 0x47, 0x55, 0x49, 0x42, 0x4c, 0x58, 0x4b, 0x42, 0x48, + 0x4b, 0x5a, 0x36, 0x33, 0x53, 0x57, 0x4d, 0x4a, 0x37, 0x4c, 0x3e, 0x48, + 0x43, 0x46, 0x39, 0x3c, 0x34, 0x65, 0x47, 0x3d, 0x47, 0x42, 0x3c, 0x3e, + 0x45, 0x5b, 0x44, 0x3e, 0x45, 0x43, 0x46, 0x43, 0x59, 0x4e, 0x48, 0x46, + 0x43, 0x3f, 0x46, 0x47, 0x4e, 0x53, 0x50, 0x4b, 0x4a, 0x3f, 0x4a, 0x54, + 0x4c, 0x4a, 0x43, 0x50, 0x4c, 0x42, 0x4d, 0x55, 0x4d, 0x51, 0x51, 0x46, + 0x49, 0x41, 0x50, 0x44, 0x4a, 0x4b, 0x4b, 0x43, 0x4b, 0x4e, 0x47, 0x4b, + 0x3e, 0x4e, 0x44, 0x4d, 0x49, 0x41, 0x49, 0x44, 0x50, 0x4d, 0x45, 0x4e, + 0x4b, 0x50, 0x45, 0x4c, 0x46, 0x4a, 0x46, 0x42, 0x50, 0x45, 0x48, 0x53, + 0x4d, 0x44, 0x42, 0x50, 0x4c, 0x49, 0x45, 0x55, 0x4d, 0x42, 0x43, 0x41, + 0x4c, 0x41, 0x4e, 0x4d, 0x42, 0x4e, 0x3f, 0x44, 0x4d, 0x4c, 0x4b, 0x4a, + 0x47, 0x47, 0x4e, 0x54, 0x43, 0x40, 0x41, 0x55, 0x49, 0x49, 0x4e, 0x49, + 0x52, 0x4e, 0x46, 0x58, 0x4b, 0x3d, 0x4a, 0x44, 0x4e, 0x47, 0x53, 0x58, + 0x47, 0x42, 0x52, 0x46, 0x49, 0x4b, 0x47, 0x5a, 0x4c, 0x46, 0x46, 0x49, + 0x4b, 0x4d, 0x3d, 0x48, 0x40, 0x54, 0x48, 0x4c, 0x4c, 0x44, 0x4c, 0x46, + 0x47, 0x4b, 0x4d, 0x44, 0x5a, 0x4a, 0x3e, 0x46, 0x48, 0x53, 0x39, 0x30, + 0x51, 0x60, 0x4d, 0x47, 0x35, 0x4f, 0x45, 0x45, 0x4a, 0x4b, 0x42, 0x3f, + 0x38, 0x6c, 0x3d, 0x40, 0x44, 0x48, 0x3a, 0x3b, 0x46, 0x5e, 0x45, 0x3b, + 0x47, 0x47, 0x45, 0x42, 0x53, 0x55, 0x44, 0x45, 0x46, 0x43, 0x48, 0x48, + 0x52, 0x5d, 0x3e, 0x41, 0x53, 0x42, 0x48, 0x55, 0x49, 0x4d, 0x4a, 0x46, + 0x52, 0x46, 0x51, 0x48, 0x44, 0x46, 0x48, 0x41, 0x49, 0x49, 0x49, 0x49, + 0x41, 0x4d, 0x40, 0x4f, 0x45, 0x46, 0x45, 0x3f, 0x53, 0x40, 0x46, 0x43, + 0x47, 0x4d, 0x50, 0x4c, 0x55, 0x48, 0x45, 0x47, 0x4f, 0x46, 0x42, 0x4d, + 0x41, 0x48, 0x46, 0x4e, 0x42, 0x48, 0x48, 0x45, 0x41, 0x45, 0x48, 0x4a, + 0x40, 0x49, 0x43, 0x4b, 0x48, 0x4a, 0x4c, 0x45, 0x4b, 0x48, 0x48, 0x4f, + 0x40, 0x4b, 0x4a, 0x44, 0x50, 0x4a, 0x43, 0x50, 0x4c, 0x44, 0x46, 0x4c, + 0x42, 0x44, 0x4e, 0x55, 0x47, 0x49, 0x48, 0x47, 0x52, 0x4e, 0x44, 0x59, + 0x4e, 0x44, 0x4a, 0x48, 0x49, 0x4a, 0x42, 0x4e, 0x3e, 0x39, 0x51, 0x45, + 0x4d, 0x49, 0x4f, 0x54, 0x51, 0x4b, 0x50, 0x44, 0x53, 0x4f, 0x4d, 0x48, + 0x42, 0x45, 0x4e, 0x40, 0x4a, 0x48, 0x43, 0x48, 0x52, 0x54, 0x4d, 0x49, + 0x5f, 0x53, 0x46, 0x4e, 0x3f, 0x5a, 0x36, 0x31, 0x52, 0x60, 0x4b, 0x4a, + 0x32, 0x51, 0x40, 0x44, 0x46, 0x52, 0x44, 0x41, 0x3a, 0x6e, 0x41, 0x3e, + 0x47, 0x3e, 0x3a, 0x2a, 0x44, 0x5a, 0x40, 0x3c, 0x4d, 0x48, 0x46, 0x3b, + 0x5e, 0x58, 0x4d, 0x47, 0x51, 0x3a, 0x4b, 0x48, 0x5b, 0x5a, 0x54, 0x43, + 0x50, 0x4c, 0x54, 0x54, 0x49, 0x47, 0x4f, 0x48, 0x50, 0x40, 0x4f, 0x4a, + 0x42, 0x42, 0x3c, 0x41, 0x43, 0x4e, 0x53, 0x49, 0x4b, 0x4d, 0x49, 0x41, + 0x4c, 0x3e, 0x40, 0x49, 0x40, 0x44, 0x49, 0x4f, 0x50, 0x4a, 0x42, 0x3a, + 0x49, 0x4b, 0x47, 0x50, 0x49, 0x41, 0x52, 0x46, 0x3d, 0x44, 0x46, 0x43, + 0x4b, 0x4b, 0x4d, 0x4b, 0x4e, 0x40, 0x45, 0x43, 0x48, 0x44, 0x55, 0x51, + 0x4a, 0x46, 0x4e, 0x40, 0x53, 0x4a, 0x45, 0x41, 0x48, 0x48, 0x45, 0x4e, + 0x4a, 0x48, 0x40, 0x4c, 0x54, 0x44, 0x42, 0x4d, 0x49, 0x43, 0x45, 0x4c, + 0x43, 0x4f, 0x46, 0x3f, 0x46, 0x4f, 0x4b, 0x59, 0x46, 0x49, 0x54, 0x47, + 0x49, 0x46, 0x45, 0x53, 0x4a, 0x49, 0x54, 0x45, 0x41, 0x45, 0x4c, 0x5e, + 0x50, 0x3d, 0x4d, 0x49, 0x55, 0x4b, 0x49, 0x47, 0x4c, 0x4f, 0x43, 0x3d, + 0x41, 0x4b, 0x43, 0x46, 0x4f, 0x4a, 0x4c, 0x54, 0x5e, 0x4e, 0x40, 0x4d, + 0x3d, 0x59, 0x40, 0x28, 0x54, 0x5f, 0x4d, 0x4b, 0x36, 0x51, 0x3a, 0x47, + 0x4a, 0x55, 0x42, 0x43, 0x3b, 0x72, 0x3b, 0x3d, 0x51, 0x42, 0x3f, 0x2d, + 0x4b, 0x5a, 0x48, 0x44, 0x49, 0x49, 0x3d, 0x39, 0x56, 0x55, 0x46, 0x46, + 0x4b, 0x43, 0x40, 0x4a, 0x52, 0x56, 0x4d, 0x45, 0x4b, 0x48, 0x40, 0x5a, + 0x4e, 0x3a, 0x53, 0x48, 0x4c, 0x44, 0x49, 0x4e, 0x42, 0x47, 0x46, 0x40, + 0x51, 0x42, 0x50, 0x4b, 0x43, 0x53, 0x44, 0x44, 0x46, 0x4c, 0x4c, 0x3c, + 0x42, 0x45, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x3d, 0x47, 0x4b, 0x4c, 0x4e, + 0x52, 0x4a, 0x4e, 0x41, 0x3f, 0x46, 0x43, 0x54, 0x44, 0x53, 0x4e, 0x48, + 0x40, 0x41, 0x4f, 0x45, 0x43, 0x3c, 0x52, 0x49, 0x40, 0x44, 0x4a, 0x3f, + 0x4d, 0x4c, 0x4f, 0x47, 0x44, 0x47, 0x55, 0x47, 0x50, 0x4d, 0x4a, 0x4c, + 0x50, 0x48, 0x47, 0x55, 0x4b, 0x4a, 0x52, 0x49, 0x3d, 0x3f, 0x4f, 0x51, + 0x48, 0x4e, 0x42, 0x4e, 0x42, 0x48, 0x4e, 0x49, 0x4a, 0x50, 0x45, 0x54, + 0x41, 0x43, 0x45, 0x4d, 0x48, 0x48, 0x48, 0x51, 0x53, 0x3e, 0x55, 0x44, + 0x52, 0x56, 0x44, 0x4d, 0x4e, 0x48, 0x4b, 0x43, 0x48, 0x53, 0x48, 0x44, + 0x49, 0x45, 0x4e, 0x50, 0x5d, 0x4a, 0x45, 0x4c, 0x45, 0x55, 0x43, 0x2e, + 0x59, 0x60, 0x4e, 0x4d, 0x32, 0x53, 0x3e, 0x3f, 0x40, 0x63, 0x41, 0x48, + 0x38, 0x73, 0x38, 0x46, 0x50, 0x3e, 0x3c, 0x23, 0x48, 0x61, 0x45, 0x3c, + 0x41, 0x41, 0x36, 0x3b, 0x58, 0x56, 0x4a, 0x40, 0x4f, 0x44, 0x45, 0x4c, + 0x5a, 0x56, 0x47, 0x3f, 0x4d, 0x4b, 0x46, 0x5d, 0x52, 0x47, 0x45, 0x4c, + 0x4a, 0x52, 0x4f, 0x4f, 0x4f, 0x43, 0x4f, 0x47, 0x43, 0x46, 0x3c, 0x4c, + 0x46, 0x55, 0x40, 0x53, 0x43, 0x3e, 0x42, 0x35, 0x51, 0x41, 0x42, 0x3f, + 0x45, 0x3d, 0x41, 0x31, 0x4e, 0x47, 0x48, 0x42, 0x41, 0x45, 0x43, 0x38, + 0x42, 0x40, 0x4a, 0x47, 0x4e, 0x43, 0x40, 0x43, 0x48, 0x49, 0x45, 0x4f, + 0x44, 0x42, 0x4d, 0x42, 0x42, 0x3f, 0x46, 0x52, 0x3c, 0x3c, 0x47, 0x43, + 0x46, 0x47, 0x45, 0x40, 0x4c, 0x44, 0x43, 0x4a, 0x4b, 0x4d, 0x4e, 0x46, + 0x51, 0x45, 0x47, 0x4b, 0x45, 0x50, 0x40, 0x42, 0x4c, 0x4c, 0x4c, 0x4f, + 0x44, 0x3c, 0x49, 0x3c, 0x3f, 0x45, 0x3f, 0x5c, 0x42, 0x3e, 0x4b, 0x4e, + 0x50, 0x45, 0x42, 0x5c, 0x4c, 0x48, 0x50, 0x52, 0x50, 0x47, 0x4b, 0x44, + 0x3d, 0x50, 0x55, 0x4c, 0x48, 0x3f, 0x4b, 0x44, 0x4a, 0x51, 0x42, 0x4c, + 0x60, 0x51, 0x41, 0x4b, 0x46, 0x5c, 0x42, 0x2c, 0x55, 0x61, 0x50, 0x52, + 0x37, 0x5a, 0x3f, 0x43, 0x43, 0x58, 0x3a, 0x4d, 0x3e, 0x72, 0x35, 0x3f, + 0x58, 0x41, 0x40, 0x1f, 0x55, 0x63, 0x3f, 0x49, 0x41, 0x3e, 0x35, 0x41, + 0x65, 0x54, 0x42, 0x45, 0x45, 0x3c, 0x44, 0x45, 0x59, 0x5a, 0x4d, 0x41, + 0x51, 0x46, 0x49, 0x59, 0x4c, 0x41, 0x42, 0x44, 0x4a, 0x45, 0x3f, 0x4a, + 0x4a, 0x44, 0x48, 0x48, 0x52, 0x40, 0x4a, 0x4a, 0x4d, 0x54, 0x44, 0x48, + 0x54, 0x46, 0x49, 0x3b, 0x42, 0x4a, 0x4e, 0x46, 0x4a, 0x45, 0x4f, 0x30, + 0x46, 0x41, 0x47, 0x46, 0x4b, 0x47, 0x46, 0x38, 0x4c, 0x3a, 0x4b, 0x46, + 0x52, 0x48, 0x4f, 0x3e, 0x48, 0x4a, 0x48, 0x4b, 0x44, 0x45, 0x4a, 0x46, + 0x3f, 0x4f, 0x40, 0x44, 0x43, 0x43, 0x4b, 0x39, 0x46, 0x43, 0x49, 0x49, + 0x49, 0x4a, 0x44, 0x48, 0x4c, 0x41, 0x4d, 0x52, 0x4c, 0x4a, 0x46, 0x3d, + 0x41, 0x4b, 0x41, 0x48, 0x45, 0x3b, 0x51, 0x54, 0x4a, 0x39, 0x4d, 0x41, + 0x54, 0x46, 0x4c, 0x53, 0x48, 0x3e, 0x4a, 0x3d, 0x41, 0x52, 0x54, 0x63, + 0x44, 0x4d, 0x4a, 0x43, 0x52, 0x4b, 0x52, 0x52, 0x4e, 0x41, 0x48, 0x42, + 0x48, 0x4d, 0x49, 0x45, 0x51, 0x48, 0x3e, 0x47, 0x5a, 0x52, 0x4a, 0x4e, + 0x3e, 0x59, 0x3c, 0x2e, 0x5c, 0x5b, 0x4c, 0x56, 0x30, 0x59, 0x3a, 0x48, + 0x3d, 0x5c, 0x44, 0x49, 0x40, 0x7c, 0x3a, 0x48, 0x54, 0x40, 0x41, 0x28, + 0x4d, 0x64, 0x46, 0x47, 0x49, 0x40, 0x30, 0x3a, 0x5f, 0x5b, 0x42, 0x37, + 0x49, 0x45, 0x40, 0x43, 0x5b, 0x54, 0x48, 0x4d, 0x4a, 0x47, 0x51, 0x58, + 0x4b, 0x3c, 0x4d, 0x46, 0x4b, 0x52, 0x4c, 0x58, 0x53, 0x46, 0x42, 0x45, + 0x4c, 0x4a, 0x4d, 0x4e, 0x52, 0x4d, 0x46, 0x44, 0x46, 0x3f, 0x46, 0x34, + 0x4f, 0x42, 0x44, 0x46, 0x44, 0x50, 0x47, 0x30, 0x44, 0x3c, 0x42, 0x46, + 0x4f, 0x4a, 0x52, 0x30, 0x55, 0x4f, 0x45, 0x4a, 0x48, 0x4c, 0x4e, 0x35, + 0x4e, 0x3c, 0x45, 0x4a, 0x45, 0x4a, 0x44, 0x3c, 0x4e, 0x4a, 0x51, 0x44, + 0x49, 0x40, 0x4a, 0x40, 0x41, 0x44, 0x4f, 0x4c, 0x43, 0x45, 0x4b, 0x43, + 0x3e, 0x3e, 0x4c, 0x44, 0x48, 0x48, 0x42, 0x42, 0x4d, 0x43, 0x50, 0x4d, + 0x49, 0x3c, 0x45, 0x4f, 0x4c, 0x46, 0x4b, 0x48, 0x4d, 0x4d, 0x49, 0x55, + 0x49, 0x3b, 0x40, 0x44, 0x4a, 0x4b, 0x4e, 0x5e, 0x43, 0x47, 0x45, 0x43, + 0x4d, 0x4d, 0x49, 0x46, 0x4a, 0x44, 0x4e, 0x3e, 0x52, 0x41, 0x47, 0x47, + 0x4a, 0x50, 0x48, 0x43, 0x5d, 0x4f, 0x49, 0x48, 0x43, 0x4f, 0x45, 0x3e, + 0x5a, 0x69, 0x4d, 0x5a, 0x3a, 0x5d, 0x3a, 0x48, 0x42, 0x55, 0x3e, 0x48, + 0x48, 0x7b, 0x37, 0x40, 0x57, 0x45, 0x48, 0x24, 0x50, 0x61, 0x4c, 0x4a, + 0x44, 0x41, 0x34, 0x38, 0x65, 0x5b, 0x4f, 0x3c, 0x4d, 0x3a, 0x4a, 0x4c, + 0x66, 0x55, 0x50, 0x47, 0x4d, 0x46, 0x47, 0x58, 0x4c, 0x48, 0x48, 0x48, + 0x4e, 0x59, 0x4f, 0x4b, 0x45, 0x45, 0x4b, 0x54, 0x46, 0x51, 0x4f, 0x44, + 0x42, 0x55, 0x48, 0x44, 0x48, 0x41, 0x53, 0x2e, 0x4d, 0x45, 0x44, 0x54, + 0x4a, 0x44, 0x53, 0x34, 0x4c, 0x46, 0x47, 0x3f, 0x4c, 0x4b, 0x47, 0x36, + 0x47, 0x41, 0x43, 0x40, 0x51, 0x46, 0x45, 0x33, 0x46, 0x3e, 0x47, 0x50, + 0x3f, 0x48, 0x48, 0x37, 0x41, 0x41, 0x42, 0x3e, 0x45, 0x3d, 0x49, 0x3e, + 0x4f, 0x42, 0x49, 0x4a, 0x46, 0x46, 0x48, 0x44, 0x49, 0x45, 0x46, 0x4a, + 0x4a, 0x47, 0x48, 0x43, 0x44, 0x45, 0x3f, 0x4c, 0x4c, 0x49, 0x4d, 0x51, + 0x4a, 0x4a, 0x49, 0x4c, 0x42, 0x4d, 0x4b, 0x4b, 0x4a, 0x42, 0x47, 0x4d, + 0x3e, 0x4b, 0x47, 0x5c, 0x49, 0x3d, 0x4e, 0x41, 0x44, 0x49, 0x3e, 0x3e, + 0x4b, 0x47, 0x4e, 0x45, 0x44, 0x4a, 0x4d, 0x4a, 0x4f, 0x46, 0x45, 0x52, + 0x60, 0x53, 0x49, 0x50, 0x3d, 0x4f, 0x43, 0x3d, 0x52, 0x64, 0x52, 0x58, + 0x39, 0x5f, 0x36, 0x4c, 0x45, 0x57, 0x42, 0x4b, 0x3f, 0x80, 0x34, 0x47, + 0x58, 0x41, 0x45, 0x1b, 0x4b, 0x5e, 0x4c, 0x40, 0x44, 0x42, 0x39, 0x3a, + 0x5e, 0x5b, 0x4b, 0x3a, 0x4b, 0x3f, 0x45, 0x3e, 0x69, 0x57, 0x4b, 0x45, + 0x4b, 0x3f, 0x45, 0x55, 0x49, 0x49, 0x48, 0x47, 0x41, 0x4f, 0x42, 0x53, + 0x49, 0x40, 0x42, 0x3e, 0x49, 0x47, 0x53, 0x47, 0x45, 0x51, 0x4a, 0x44, + 0x44, 0x45, 0x4e, 0x2a, 0x45, 0x42, 0x4a, 0x4b, 0x46, 0x4d, 0x41, 0x30, + 0x3d, 0x43, 0x3f, 0x48, 0x49, 0x44, 0x4d, 0x2e, 0x48, 0x4a, 0x4c, 0x51, + 0x50, 0x46, 0x3e, 0x2c, 0x4d, 0x3f, 0x47, 0x46, 0x3c, 0x40, 0x4c, 0x38, + 0x4f, 0x46, 0x47, 0x53, 0x3b, 0x3c, 0x4e, 0x3e, 0x49, 0x40, 0x43, 0x4c, + 0x4d, 0x48, 0x45, 0x3c, 0x4d, 0x4c, 0x4d, 0x45, 0x3f, 0x49, 0x4a, 0x43, + 0x4d, 0x41, 0x4b, 0x50, 0x4e, 0x46, 0x50, 0x44, 0x49, 0x44, 0x4e, 0x42, + 0x4a, 0x43, 0x4c, 0x4c, 0x49, 0x49, 0x44, 0x4e, 0x4b, 0x3f, 0x4b, 0x5d, + 0x41, 0x49, 0x4b, 0x46, 0x4e, 0x48, 0x45, 0x51, 0x4d, 0x45, 0x46, 0x45, + 0x4b, 0x4e, 0x3c, 0x4d, 0x3d, 0x41, 0x47, 0x47, 0x64, 0x54, 0x41, 0x55, + 0x47, 0x56, 0x44, 0x3b, 0x53, 0x66, 0x4f, 0x5e, 0x40, 0x5d, 0x38, 0x4a, + 0x41, 0x59, 0x42, 0x48, 0x47, 0xff, 0x36, 0x49, 0x59, 0x41, 0x43, 0x1d, + 0x4d, 0x5e, 0x44, 0x44, 0x50, 0x3f, 0x39, 0x40, 0x68, 0x5e, 0x4a, 0x41, + 0x52, 0x41, 0x43, 0x41, 0x68, 0x51, 0x45, 0x48, 0x4c, 0x46, 0x4a, 0x5e, + 0x4e, 0x40, 0x4d, 0x41, 0x41, 0x5c, 0x3f, 0x4e, 0x4c, 0x37, 0x48, 0x40, + 0x46, 0x47, 0x4f, 0x43, 0x53, 0x52, 0x3d, 0x44, 0x47, 0x44, 0x3d, 0x34, + 0x44, 0x42, 0x4a, 0x43, 0x4d, 0x3f, 0x53, 0x2e, 0x42, 0x47, 0x43, 0x4d, + 0x45, 0x45, 0x47, 0x31, 0x4d, 0x39, 0x41, 0x4a, 0x4a, 0x4d, 0x4b, 0x35, + 0x47, 0x4e, 0x4c, 0x40, 0x4a, 0x44, 0x44, 0x36, 0x3e, 0x49, 0x3f, 0x45, + 0x46, 0x43, 0x4e, 0x3c, 0x4d, 0x47, 0x4c, 0x48, 0x4a, 0x4b, 0x48, 0x39, + 0x46, 0x50, 0x4a, 0x4f, 0x46, 0x41, 0x44, 0x4a, 0x41, 0x4f, 0x4c, 0x4e, + 0x55, 0x46, 0x43, 0x46, 0x4a, 0x48, 0x4e, 0x46, 0x42, 0x40, 0x4f, 0x56, + 0x4c, 0x45, 0x4b, 0x46, 0x4a, 0x47, 0x42, 0x5e, 0x49, 0x4e, 0x46, 0x43, + 0x4e, 0x42, 0x45, 0x48, 0x47, 0x48, 0x4f, 0x45, 0x47, 0x51, 0x4b, 0x4c, + 0x51, 0x39, 0x4d, 0x48, 0x60, 0x57, 0x49, 0x52, 0x3d, 0x57, 0x46, 0x3d, + 0x53, 0x68, 0x4b, 0x60, 0x40, 0x5a, 0x41, 0x4b, 0x46, 0x56, 0x46, 0x4c, + 0x49, 0x7e, 0x2f, 0x48, 0x51, 0x42, 0x40, 0x20, 0x4b, 0x62, 0x4d, 0x41, + 0x4f, 0x43, 0x3d, 0x35, 0x63, 0x63, 0x46, 0x3e, 0x4e, 0x47, 0x40, 0x40, + 0x60, 0x52, 0x4c, 0x46, 0x49, 0x48, 0x4f, 0x56, 0x51, 0x47, 0x52, 0x4e, + 0x4b, 0x59, 0x55, 0x4f, 0x48, 0x3d, 0x48, 0x4a, 0x4d, 0x50, 0x47, 0x47, + 0x51, 0x52, 0x4d, 0x51, 0x45, 0x45, 0x47, 0x2d, 0x4d, 0x41, 0x43, 0x49, + 0x4d, 0x40, 0x4a, 0x2f, 0x4f, 0x43, 0x46, 0x4a, 0x3e, 0x4a, 0x4a, 0x2b, + 0x49, 0x4c, 0x4c, 0x3e, 0x41, 0x4c, 0x4a, 0x2b, 0x40, 0x44, 0x46, 0x4a, + 0x40, 0x44, 0x42, 0x38, 0x52, 0x42, 0x46, 0x51, 0x53, 0x4e, 0x45, 0x31, + 0x45, 0x47, 0x4f, 0x46, 0x49, 0x43, 0x45, 0x3b, 0x4b, 0x4b, 0x4b, 0x4c, + 0x43, 0x4a, 0x4c, 0x43, 0x4e, 0x40, 0x52, 0x44, 0x48, 0x49, 0x47, 0x4b, + 0x4e, 0x3d, 0x4e, 0x44, 0x48, 0x4d, 0x4f, 0x4f, 0x50, 0x36, 0x47, 0x41, + 0x4a, 0x44, 0x45, 0x56, 0x4f, 0x4c, 0x50, 0x4b, 0x45, 0x3e, 0x45, 0x4e, + 0x45, 0x45, 0x43, 0x40, 0x47, 0x4e, 0x45, 0x3e, 0x4a, 0x3f, 0x49, 0x50, + 0x62, 0x55, 0x48, 0x56, 0x3e, 0x57, 0x4f, 0x3b, 0x55, 0x6c, 0x50, 0x5c, + 0x3d, 0x54, 0x3d, 0x46, 0x43, 0x59, 0x3e, 0x51, 0x4d, 0x7b, 0x33, 0x47, + 0x52, 0x43, 0x3f, 0x25, 0x4a, 0x6f, 0x49, 0x3e, 0x50, 0x40, 0x41, 0x30, + 0x5e, 0x5c, 0x4a, 0x43, 0x4d, 0x42, 0x46, 0x3b, 0x63, 0x53, 0x4f, 0x43, + 0x58, 0x48, 0x4b, 0x59, 0x50, 0x4e, 0x4b, 0x51, 0x4a, 0x55, 0x44, 0x46, + 0x4c, 0x3d, 0x4c, 0x52, 0x44, 0x52, 0x4c, 0x41, 0x4f, 0x44, 0x4a, 0x47, + 0x4e, 0x48, 0x49, 0x2e, 0x3e, 0x45, 0x4c, 0x48, 0x41, 0x47, 0x4d, 0x2e, + 0x40, 0x4b, 0x4c, 0x42, 0x4d, 0x40, 0x4e, 0x2e, 0x43, 0x45, 0x4b, 0x43, + 0x3e, 0x49, 0x55, 0x35, 0x43, 0x42, 0x42, 0x40, 0x4e, 0x46, 0x44, 0x37, + 0x49, 0x41, 0x3f, 0x52, 0x47, 0x4b, 0x43, 0x33, 0x4b, 0x47, 0x4b, 0x4c, + 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x40, 0x49, 0x41, 0x42, 0x49, 0x4b, 0x46, + 0x4e, 0x4e, 0x47, 0x4e, 0x48, 0x48, 0x4b, 0x46, 0x51, 0x4b, 0x46, 0x4d, + 0x47, 0x4f, 0x3e, 0x51, 0x46, 0x4e, 0x46, 0x4b, 0x47, 0x48, 0x4e, 0x55, + 0x4c, 0x3d, 0x47, 0x51, 0x42, 0x45, 0x4f, 0x42, 0x52, 0x50, 0x44, 0x4c, + 0x44, 0x44, 0x43, 0x4d, 0x40, 0x42, 0x4d, 0x4b, 0x5d, 0x4e, 0x47, 0x54, + 0x47, 0x51, 0x43, 0x39, 0x58, 0x66, 0x4e, 0x5a, 0x41, 0x52, 0x36, 0x47, + 0x45, 0x5f, 0x34, 0x50, 0x46, 0x79, 0x30, 0x48, 0x50, 0x45, 0x32, 0x22, + 0x54, 0x64, 0x49, 0x46, 0x45, 0x3c, 0x42, 0x36, 0x65, 0x5c, 0x48, 0x3a, + 0x4d, 0x4b, 0x47, 0x3e, 0x63, 0x56, 0x4a, 0x48, 0x51, 0x42, 0x4f, 0x5e, + 0x4c, 0x44, 0x4b, 0x4c, 0x3d, 0x5a, 0x43, 0x4d, 0x42, 0x40, 0x4f, 0x4d, + 0x3f, 0x3e, 0x46, 0x40, 0x49, 0x42, 0x49, 0x40, 0x49, 0x4c, 0x4a, 0x2e, + 0x4b, 0x3f, 0x53, 0x4b, 0x48, 0x49, 0x3e, 0x34, 0x47, 0x4a, 0x4b, 0x46, + 0x3b, 0x49, 0x46, 0x34, 0x4b, 0x48, 0x4c, 0x49, 0x49, 0x43, 0x4f, 0x2e, + 0x44, 0x46, 0x48, 0x50, 0x46, 0x4e, 0x4a, 0x37, 0x4b, 0x4c, 0x4a, 0x50, + 0x45, 0x4a, 0x48, 0x3b, 0x48, 0x44, 0x48, 0x4a, 0x41, 0x44, 0x52, 0x3f, + 0x4c, 0x46, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x36, 0x53, 0x3e, 0x48, 0x47, + 0x3f, 0x42, 0x41, 0x4c, 0x42, 0x4a, 0x52, 0x46, 0x49, 0x3f, 0x48, 0x5a, + 0x43, 0x42, 0x3d, 0x43, 0x4f, 0x44, 0x43, 0x65, 0x41, 0x41, 0x44, 0x4b, + 0x50, 0x44, 0x53, 0x49, 0x41, 0x45, 0x4a, 0x4d, 0x40, 0x45, 0x4a, 0x4e, + 0x50, 0x40, 0x51, 0x40, 0x5e, 0x50, 0x43, 0x5c, 0x47, 0x5a, 0x44, 0x4c, + 0x54, 0x64, 0x4f, 0x63, 0x39, 0x58, 0x3c, 0x4a, 0x42, 0x5e, 0x3c, 0x4a, + 0x48, 0x7b, 0x34, 0x4c, 0x4f, 0x44, 0x30, 0x24, 0x50, 0x65, 0x47, 0x39, + 0x46, 0x3e, 0x3f, 0x33, 0x65, 0x5a, 0x44, 0x38, 0x50, 0x47, 0x4b, 0x3e, + 0x5b, 0x53, 0x4a, 0x4d, 0x51, 0x40, 0x47, 0x59, 0x51, 0x42, 0x4f, 0x50, + 0x45, 0x57, 0x46, 0x50, 0x3f, 0x3c, 0x4c, 0x4f, 0x46, 0x41, 0x4a, 0x3e, + 0x4d, 0x45, 0x51, 0x48, 0x4e, 0x44, 0x4e, 0x35, 0x44, 0x3f, 0x44, 0x48, + 0x3c, 0x4c, 0x49, 0x2c, 0x4a, 0x46, 0x48, 0x44, 0x4b, 0x42, 0x4b, 0x2f, + 0x4e, 0x50, 0x4c, 0x4d, 0x44, 0x46, 0x3f, 0x39, 0x4d, 0x47, 0x45, 0x41, + 0x42, 0x47, 0x4a, 0x3a, 0x40, 0x3e, 0x4a, 0x51, 0x3f, 0x47, 0x44, 0x37, + 0x47, 0x4e, 0x47, 0x52, 0x45, 0x42, 0x4a, 0x3d, 0x43, 0x4d, 0x4d, 0x47, + 0x48, 0x43, 0x44, 0x44, 0x47, 0x4e, 0x52, 0x4b, 0x4e, 0x50, 0x42, 0x47, + 0x4b, 0x4b, 0x4e, 0x4c, 0x4e, 0x47, 0x50, 0x56, 0x46, 0x47, 0x4d, 0x49, + 0x4d, 0x46, 0x49, 0x5f, 0x49, 0x42, 0x4d, 0x44, 0x40, 0x4b, 0x52, 0x45, + 0x46, 0x4a, 0x4b, 0x49, 0x47, 0x4b, 0x42, 0x45, 0x42, 0x44, 0x46, 0x4c, + 0x62, 0x4a, 0x44, 0x53, 0x43, 0x5a, 0x48, 0x49, 0x59, 0x68, 0x46, 0x61, + 0x40, 0x5a, 0x3a, 0x4d, 0x45, 0x5e, 0x33, 0x4f, 0x4e, 0x74, 0x3e, 0x3e, + 0x5a, 0x4b, 0x34, 0x31, 0x52, 0x6c, 0x44, 0x39, 0x4c, 0x3b, 0x39, 0x3a, + 0x63, 0x65, 0x4b, 0x40, 0x50, 0x4d, 0x53, 0x4a, 0x69, 0x56, 0x54, 0x45, + 0x4c, 0x4c, 0x50, 0x5b, 0x4d, 0x4f, 0x3d, 0x4b, 0x44, 0x47, 0x43, 0x47, + 0x49, 0x3c, 0x49, 0x41, 0x41, 0x3f, 0x47, 0x43, 0x48, 0x47, 0x4c, 0x43, + 0x4a, 0x40, 0x4d, 0x32, 0x4b, 0x4d, 0x44, 0x48, 0x46, 0x44, 0x50, 0x2f, + 0x4e, 0x49, 0x53, 0x4b, 0x52, 0x47, 0x4b, 0x2b, 0x48, 0x4b, 0x4a, 0x4c, + 0x4d, 0x4c, 0x43, 0x37, 0x48, 0x3c, 0x4b, 0x42, 0x51, 0x3f, 0x45, 0x3c, + 0x49, 0x40, 0x42, 0x43, 0x4d, 0x4c, 0x3f, 0x3f, 0x4d, 0x43, 0x45, 0x42, + 0x48, 0x42, 0x48, 0x39, 0x51, 0x4e, 0x46, 0x4f, 0x3e, 0x4c, 0x45, 0x3e, + 0x3f, 0x3f, 0x43, 0x41, 0x4b, 0x4b, 0x43, 0x4d, 0x44, 0x3b, 0x48, 0x45, + 0x3c, 0x4a, 0x48, 0x5b, 0x3c, 0x4b, 0x4c, 0x44, 0x46, 0x3e, 0x45, 0x57, + 0x43, 0x42, 0x51, 0x4a, 0x46, 0x47, 0x43, 0x49, 0x42, 0x43, 0x50, 0x4e, + 0x4e, 0x44, 0x41, 0x4e, 0x4e, 0x41, 0x48, 0x47, 0x5c, 0x53, 0x44, 0x54, + 0x44, 0x5b, 0x45, 0x46, 0x55, 0x67, 0x4d, 0x5d, 0x40, 0x5a, 0x43, 0x4b, + 0x43, 0x60, 0x3c, 0x4b, 0x41, 0x79, 0x41, 0x41, 0x58, 0x48, 0x40, 0x3b, + 0x4f, 0x6c, 0x46, 0x3f, 0x53, 0x3a, 0x3d, 0x36, 0x5a, 0x57, 0x44, 0x41, + 0x4c, 0x47, 0x4e, 0x48, 0x62, 0x60, 0x4a, 0x46, 0x51, 0x3e, 0x52, 0x5f, + 0x4b, 0x46, 0x48, 0x4c, 0x4c, 0x55, 0x43, 0x46, 0x49, 0x3e, 0x41, 0x40, + 0x4d, 0x47, 0x46, 0x3b, 0x51, 0x3a, 0x4a, 0x45, 0x50, 0x47, 0x51, 0x38, + 0x44, 0x41, 0x40, 0x4b, 0x4d, 0x44, 0x4d, 0x28, 0x47, 0x3e, 0x44, 0x40, + 0x49, 0x49, 0x40, 0x3c, 0x44, 0x4c, 0x48, 0x51, 0x46, 0x3e, 0x47, 0x2a, + 0x41, 0x44, 0x49, 0x4c, 0x4e, 0x4e, 0x42, 0x3c, 0x49, 0x42, 0x43, 0x45, + 0x4e, 0x4d, 0x50, 0x39, 0x42, 0x43, 0x48, 0x41, 0x3f, 0x40, 0x4e, 0x3a, + 0x44, 0x3d, 0x49, 0x4d, 0x47, 0x45, 0x4b, 0x42, 0x4c, 0x4d, 0x3f, 0x3f, + 0x4e, 0x4d, 0x4d, 0x4d, 0x4d, 0x45, 0x47, 0x43, 0x4c, 0x46, 0x47, 0x57, + 0x4b, 0x42, 0x4d, 0x46, 0x4b, 0x4b, 0x43, 0x58, 0x48, 0x49, 0x4d, 0x47, + 0x43, 0x49, 0x4b, 0x48, 0x46, 0x4f, 0x4f, 0x42, 0x4a, 0x43, 0x49, 0x4e, + 0x4a, 0x47, 0x4c, 0x48, 0x5a, 0x57, 0x4a, 0x58, 0x49, 0x4f, 0x45, 0x47, + 0x63, 0x66, 0x4d, 0x5e, 0x4b, 0x51, 0x45, 0x4a, 0x43, 0x5d, 0x33, 0x4b, + 0x4e, 0x70, 0x42, 0x39, 0x57, 0x4a, 0x40, 0x3a, 0x51, 0x68, 0x45, 0x45, + 0x4c, 0x44, 0x3a, 0x3a, 0x4f, 0x62, 0x49, 0x45, 0x53, 0x4c, 0x4e, 0x41, + 0x63, 0x5e, 0x44, 0x44, 0x47, 0x43, 0x47, 0x59, 0x4c, 0x4b, 0x4c, 0x49, + 0x3e, 0x43, 0x4c, 0x46, 0x4c, 0x38, 0x47, 0x46, 0x46, 0x47, 0x40, 0x44, + 0x51, 0x3e, 0x40, 0x47, 0x3f, 0x45, 0x48, 0x2a, 0x42, 0x3e, 0x43, 0x46, + 0x50, 0x4c, 0x4a, 0x2c, 0x49, 0x4b, 0x48, 0x48, 0x40, 0x4a, 0x4a, 0x37, + 0x4e, 0x42, 0x4f, 0x4c, 0x41, 0x43, 0x45, 0x38, 0x4e, 0x3d, 0x41, 0x47, + 0x42, 0x42, 0x43, 0x3b, 0x4a, 0x40, 0x48, 0x4a, 0x53, 0x44, 0x4d, 0x35, + 0x51, 0x3c, 0x4e, 0x4e, 0x3e, 0x3f, 0x4b, 0x3c, 0x3e, 0x47, 0x41, 0x48, + 0x40, 0x46, 0x4e, 0x44, 0x49, 0x42, 0x49, 0x44, 0x4b, 0x46, 0x46, 0x43, + 0x4c, 0x4b, 0x49, 0x4d, 0x3d, 0x47, 0x43, 0x5c, 0x4a, 0x42, 0x47, 0x4e, + 0x47, 0x40, 0x4c, 0x55, 0x3f, 0x45, 0x46, 0x49, 0x46, 0x48, 0x49, 0x4d, + 0x4c, 0x41, 0x49, 0x40, 0x4a, 0x44, 0x42, 0x49, 0x52, 0x41, 0x49, 0x4a, + 0x5c, 0x53, 0x47, 0x58, 0x49, 0x55, 0x4a, 0x4a, 0x62, 0x61, 0x4b, 0x57, + 0x3c, 0x50, 0x42, 0x4c, 0x49, 0x5f, 0x3f, 0x4a, 0x42, 0x70, 0x40, 0x40, + 0x4f, 0x46, 0x43, 0x43, 0x4d, 0x6c, 0x41, 0x3e, 0x4e, 0x49, 0x43, 0x38, + 0x50, 0x57, 0x43, 0x39, 0x4a, 0x4f, 0x51, 0x3e, 0x5c, 0x57, 0x46, 0x49, + 0x41, 0x40, 0x42, 0x4f, 0x4c, 0x45, 0x46, 0x4a, 0x4c, 0x4b, 0x43, 0x42, + 0x4c, 0x3c, 0x47, 0x47, 0x4f, 0x44, 0x45, 0x3a, 0x4d, 0x3d, 0x4d, 0x3f, + 0x46, 0x4f, 0x41, 0x37, 0x46, 0x45, 0x54, 0x47, 0x4e, 0x46, 0x47, 0x23, + 0x48, 0x4e, 0x4a, 0x47, 0x45, 0x45, 0x4e, 0x33, 0x49, 0x4a, 0x4d, 0x4e, + 0x49, 0x46, 0x49, 0x36, 0x48, 0x44, 0x53, 0x44, 0x4a, 0x45, 0x4a, 0x37, + 0x45, 0x36, 0x4b, 0x4e, 0x50, 0x3f, 0x49, 0x38, 0x40, 0x43, 0x46, 0x4c, + 0x43, 0x46, 0x4a, 0x3f, 0x45, 0x3d, 0x44, 0x47, 0x44, 0x42, 0x4a, 0x45, + 0x47, 0x43, 0x4d, 0x4d, 0x44, 0x44, 0x4f, 0x4a, 0x4a, 0x41, 0x50, 0x50, + 0x4b, 0x44, 0x54, 0x5c, 0x4b, 0x3a, 0x46, 0x4a, 0x4a, 0x43, 0x48, 0x5c, + 0x4b, 0x43, 0x47, 0x3d, 0x3e, 0x54, 0x42, 0x47, 0x42, 0x4f, 0x4b, 0x4b, + 0x46, 0x46, 0x46, 0x42, 0x42, 0x4b, 0x48, 0x45, 0x51, 0x4e, 0x49, 0x4d, + 0x43, 0x56, 0x45, 0x40, 0x5a, 0x58, 0x4c, 0x55, 0x40, 0x4b, 0x4c, 0x51, + 0x42, 0x59, 0x43, 0x46, 0x46, 0x69, 0x43, 0x3c, 0x54, 0x47, 0x3d, 0x41, + 0x52, 0x64, 0x44, 0x38, 0x4f, 0x49, 0x3a, 0x3a, 0x55, 0x54, 0x45, 0x3e, + 0x49, 0x44, 0x4e, 0x3f, 0x57, 0x50, 0x47, 0x43, 0x45, 0x48, 0x53, 0x5b, + 0x53, 0x4d, 0x48, 0x4e, 0x48, 0x3a, 0x3e, 0x46, 0x42, 0x36, 0x50, 0x4d, + 0x49, 0x4b, 0x4b, 0x45, 0x4c, 0x44, 0x50, 0x47, 0x3e, 0x49, 0x50, 0x37, + 0x4c, 0x4b, 0x4a, 0x54, 0x4e, 0x43, 0x40, 0x25, 0x46, 0x42, 0x52, 0x3d, + 0x44, 0x45, 0x51, 0x2e, 0x4a, 0x3d, 0x46, 0x46, 0x4c, 0x42, 0x48, 0x34, + 0x44, 0x44, 0x44, 0x4c, 0x4f, 0x4b, 0x42, 0x3d, 0x45, 0x40, 0x47, 0x49, + 0x43, 0x41, 0x3e, 0x39, 0x47, 0x4b, 0x50, 0x4a, 0x46, 0x47, 0x4e, 0x3b, + 0x4e, 0x3e, 0x49, 0x4a, 0x50, 0x40, 0x43, 0x49, 0x48, 0x3c, 0x4f, 0x45, + 0x4a, 0x41, 0x42, 0x48, 0x4b, 0x46, 0x4a, 0x50, 0x40, 0x49, 0x44, 0x54, + 0x45, 0x45, 0x4a, 0x4b, 0x51, 0x51, 0x48, 0x53, 0x50, 0x3f, 0x50, 0x46, + 0x44, 0x45, 0x51, 0x43, 0x4f, 0x3e, 0x41, 0x41, 0x46, 0x45, 0x45, 0x4c, + 0x54, 0x3c, 0x4a, 0x4c, 0x5a, 0x4f, 0x46, 0x4b, 0x47, 0x4a, 0x43, 0x4c, + 0x56, 0x5a, 0x4a, 0x53, 0x4c, 0x49, 0x46, 0x4c, 0x45, 0x59, 0x40, 0x4b, + 0x48, 0x60, 0x3d, 0x42, 0x52, 0x3f, 0x42, 0x3d, 0x52, 0x5f, 0x46, 0x42, + 0x4b, 0x4e, 0x4a, 0x3d, 0x52, 0x55, 0x53, 0x37, 0x47, 0x3e, 0x4a, 0x42, + 0x51, 0x54, 0x48, 0x48, 0x4b, 0x48, 0x3e, 0x52, 0x41, 0x4e, 0x4c, 0x4f, + 0x43, 0x3b, 0x4b, 0x4b, 0x4c, 0x40, 0x48, 0x49, 0x4d, 0x3a, 0x45, 0x3c, + 0x53, 0x44, 0x48, 0x4d, 0x4b, 0x49, 0x46, 0x3c, 0x4d, 0x40, 0x51, 0x3f, + 0x4c, 0x45, 0x44, 0x2f, 0x49, 0x51, 0x3f, 0x4d, 0x3e, 0x4e, 0x3c, 0x30, + 0x3d, 0x48, 0x4f, 0x3f, 0x45, 0x45, 0x46, 0x3b, 0x4c, 0x46, 0x4d, 0x50, + 0x4c, 0x3d, 0x41, 0x37, 0x3e, 0x3e, 0x4f, 0x4b, 0x4d, 0x4f, 0x45, 0x45, + 0x4a, 0x47, 0x4a, 0x44, 0x43, 0x46, 0x51, 0x41, 0x4e, 0x39, 0x44, 0x4a, + 0x4e, 0x49, 0x4a, 0x42, 0x49, 0x4b, 0x4e, 0x48, 0x49, 0x4a, 0x45, 0x4a, + 0x45, 0x41, 0x4a, 0x4b, 0x42, 0x41, 0x48, 0x4a, 0x44, 0x3a, 0x46, 0x49, + 0x54, 0x45, 0x44, 0x60, 0x4a, 0x4e, 0x45, 0x4a, 0x4a, 0x45, 0x4b, 0x49, + 0x42, 0x44, 0x46, 0x50, 0x4b, 0x4b, 0x4e, 0x45, 0x48, 0x3e, 0x55, 0x42, + 0x51, 0x49, 0x49, 0x44, 0x4e, 0x54, 0x53, 0x49, 0x4c, 0x63, 0x48, 0x5a, + 0x50, 0x4b, 0x45, 0x49, 0x43, 0x57, 0x4c, 0x3f, 0x4d, 0x67, 0x3f, 0x47, + 0x53, 0x49, 0x43, 0x44, 0x49, 0x61, 0x50, 0x47, 0x49, 0x49, 0x4a, 0x42, + 0x4a, 0x51, 0x46, 0x43, 0x3f, 0x34, 0x40, 0x3a, 0x45, 0x54, 0x4c, 0x55, + 0x40, 0x3c, 0x4a, 0x4d, 0x3e, 0x4d, 0x48, 0x51, 0x4c, 0x3e, 0x4c, 0x4f, + 0x50, 0x47, 0x4d, 0x49, 0x4d, 0x4e, 0x45, 0x43, 0x41, 0x41, 0x40, 0x47, + 0x43, 0x4a, 0x4a, 0x3c, 0x4c, 0x3d, 0x4e, 0x43, 0x41, 0x42, 0x4a, 0x30, + 0x45, 0x4c, 0x45, 0x55, 0x46, 0x39, 0x43, 0x39, 0x45, 0x47, 0x48, 0x53, + 0x4a, 0x48, 0x43, 0x38, 0x4f, 0x51, 0x4d, 0x4c, 0x41, 0x46, 0x40, 0x3d, + 0x43, 0x4b, 0x40, 0x46, 0x47, 0x50, 0x4a, 0x43, 0x50, 0x4e, 0x45, 0x4f, + 0x4d, 0x44, 0x4d, 0x3f, 0x4e, 0x48, 0x4a, 0x49, 0x44, 0x3d, 0x4a, 0x44, + 0x40, 0x45, 0x49, 0x40, 0x4a, 0x44, 0x4f, 0x4a, 0x43, 0x4a, 0x4e, 0x52, + 0x4d, 0x50, 0x48, 0x4c, 0x43, 0x45, 0x4d, 0x54, 0x4a, 0x49, 0x4c, 0x58, + 0x4c, 0x48, 0x4c, 0x44, 0x4b, 0x4e, 0x52, 0x44, 0x49, 0x44, 0x47, 0x4e, + 0x4b, 0x45, 0x49, 0x3e, 0x4c, 0x3b, 0x53, 0x3f, 0x51, 0x41, 0x3f, 0x44, + 0x43, 0x4a, 0x4b, 0x43, 0x53, 0x57, 0x50, 0x53, 0x4f, 0x4b, 0x48, 0x51, + 0x47, 0x49, 0x46, 0x4d, 0x4d, 0x5e, 0x44, 0x46, 0x56, 0x3d, 0x3c, 0x3e, + 0x47, 0x55, 0x54, 0x46, 0x42, 0x49, 0x4f, 0x43, 0x48, 0x54, 0x51, 0x40, + 0x44, 0x44, 0x47, 0x45, 0x4b, 0x59, 0x4d, 0x47, 0x40, 0x39, 0x48, 0x54, + 0x43, 0x45, 0x44, 0x42, 0x4c, 0x3c, 0x4d, 0x42, 0x4b, 0x45, 0x42, 0x48, + 0x51, 0x44, 0x45, 0x3f, 0x3d, 0x49, 0x4b, 0x4a, 0x41, 0x43, 0x4f, 0x3f, + 0x51, 0x4b, 0x44, 0x46, 0x46, 0x44, 0x53, 0x3d, 0x47, 0x47, 0x43, 0x4b, + 0x41, 0x43, 0x3c, 0x3b, 0x49, 0x47, 0x47, 0x49, 0x4b, 0x3d, 0x43, 0x43, + 0x4b, 0x47, 0x45, 0x4e, 0x42, 0x4a, 0x4c, 0x3e, 0x51, 0x3e, 0x46, 0x44, + 0x46, 0x43, 0x42, 0x42, 0x47, 0x4d, 0x51, 0x4b, 0x49, 0x44, 0x4d, 0x40, + 0x50, 0x43, 0x41, 0x4c, 0x42, 0x49, 0x49, 0x4c, 0x42, 0x50, 0x48, 0x3f, + 0x46, 0x42, 0x48, 0x57, 0x49, 0x4d, 0x47, 0x4e, 0x48, 0x4b, 0x46, 0x50, + 0x47, 0x45, 0x52, 0x45, 0x4b, 0x48, 0x40, 0x5b, 0x4e, 0x43, 0x51, 0x48, + 0x48, 0x4a, 0x4a, 0x4a, 0x52, 0x51, 0x4c, 0x4b, 0x42, 0x55, 0x4d, 0x46, + 0x50, 0x40, 0x4a, 0x50, 0x51, 0x3e, 0x42, 0x4c, 0x43, 0x46, 0x4d, 0x46, + 0x46, 0x4d, 0x4d, 0x52, 0x4e, 0x44, 0x45, 0x47, 0x49, 0x4c, 0x41, 0x44, + 0x4d, 0x54, 0x4c, 0x4a, 0x54, 0x3e, 0x44, 0x43, 0x53, 0x55, 0x4b, 0x4a, + 0x47, 0x47, 0x4f, 0x46, 0x4f, 0x4b, 0x51, 0x3f, 0x41, 0x4c, 0x43, 0x46, + 0x55, 0x51, 0x40, 0x4b, 0x4f, 0x40, 0x47, 0x50, 0x4e, 0x4a, 0x46, 0x4e, + 0x42, 0x4d, 0x48, 0x49, 0x48, 0x4a, 0x4a, 0x43, 0x49, 0x48, 0x44, 0x3b, + 0x51, 0x46, 0x3d, 0x43, 0x47, 0x4a, 0x4f, 0x42, 0x4a, 0x50, 0x4f, 0x41, + 0x45, 0x45, 0x43, 0x3c, 0x4c, 0x4c, 0x46, 0x4b, 0x3e, 0x44, 0x4b, 0x3a, + 0x45, 0x50, 0x42, 0x48, 0x46, 0x47, 0x44, 0x3a, 0x53, 0x46, 0x4e, 0x4f, + 0x43, 0x40, 0x46, 0x48, 0x4e, 0x45, 0x3f, 0x47, 0x48, 0x3f, 0x44, 0x4f, + 0x44, 0x47, 0x4e, 0x47, 0x47, 0x49, 0x42, 0x43, 0x3f, 0x49, 0x4a, 0x53, + 0x53, 0x4a, 0x4e, 0x4a, 0x49, 0x4d, 0x49, 0x41, 0x48, 0x4d, 0x4d, 0x4e, + 0x4b, 0x45, 0x4d, 0x4a, 0x46, 0x4a, 0x46, 0x51, 0x4b, 0x47, 0x49, 0x45, + 0x49, 0x49, 0x4b, 0x5c, 0x48, 0x42, 0x51, 0x4c, 0x41, 0x3f, 0x4c, 0x42, + 0x4f, 0x45, 0x4b, 0x4a, 0x52, 0x48, 0x53, 0x4f, 0x40, 0x47, 0x41, 0x47, + 0x68, 0xfb, 0xff, 0xff, 0x4c, 0xfc, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, + 0x58, 0x01, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00, + 0x38, 0x02, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0xa0, 0x01, 0x00, 0x00, + 0x14, 0x03, 0x00, 0x00, 0xfe, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x10, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x19, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, 0x00, 0x00, 0x00, 0x00, + 0xcc, 0xfc, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x17, 0xbf, 0xd2, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x58, 0xec, 0xd1, 0x43, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6e, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x43, 0x6f, 0x6e, 0x76, + 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x34, 0xff, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a, 0xc2, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, + 0x61, 0x70, 0x65, 0x5f, 0x31, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff, + 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x43, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0xfe, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x4d, 0x61, 0x74, 0x4d, + 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x0c, 0x00, 0x0c, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xc5, 0x01, 0x2a, 0x3b, 0x96, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x25, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, + 0x71, 0x75, 0x61, 0x6e, 0x74, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, + 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, + 0x78, 0x56, 0x61, 0x72, 0x73, 0x00, 0x00, 0x00, 0x84, 0xfe, 0xff, 0xff, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a, + 0x01, 0x00, 0x00, 0x00, 0x6e, 0x88, 0xae, 0x3d, 0x01, 0x00, 0x00, 0x00, + 0xd4, 0x97, 0x30, 0xbe, 0x26, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, + 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f, + 0x31, 0x00, 0x00, 0x00, 0xec, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x2f, 0xad, 0x18, 0x40, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x38, 0xa2, 0x43, 0x01, 0x00, 0x00, 0x00, 0x02, 0xf1, 0x8d, 0xc3, + 0x8e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x5f, 0x73, + 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, 0x5c, 0xff, 0xff, 0xff, + 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00, + 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x30, 0x11, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, + 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, + 0x74, 0x5f, 0x31, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, + 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, + 0x61, 0x72, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, + 0x65, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x31, 0x83, 0xce, 0x3a, 0x01, 0x00, 0x00, 0x00, + 0x4d, 0x97, 0x92, 0x3e, 0x01, 0x00, 0x00, 0x00, 0x84, 0x75, 0xec, 0xbd, + 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x14, 0x00, 0x1c, 0x00, + 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x18, 0x00, + 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00, + 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00, + 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; +const int g_tiny_conv_model_data_len = 19800; diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h new file mode 100644 index 0000000000000000000000000000000000000000..2953cc852d98fa9b5551ae5036048de9c2ebf674 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a standard TensorFlow Lite model file that has been converted into a +// C data array, so it can be easily compiled into a binary for devices that +// don't have a file system. It was created using the command: +// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ + +extern const unsigned char g_tiny_conv_model_data[]; +extern const int g_tiny_conv_model_data_len; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a012f950e6f58f082d0a7c9ac0b4cd9018bcf40b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD @@ -0,0 +1,107 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl", + "tflite_micro_cc_test", +) + +cc_library( + name = "micro_ops", + srcs = [ + "depthwise_conv.cc", + "fully_connected.cc", + "softmax.cc", + ], + hdrs = [ + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels:padding", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "all_ops_resolver", + srcs = [ + "all_ops_resolver.cc", + ], + hdrs = [ + "all_ops_resolver.h", + ], + copts = tflite_copts(), + deps = [ + ":micro_ops", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + ], +) + +cc_library( + name = "test_utils", + srcs = [ + ], + hdrs = [ + "test_utils.h", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "depthwise_conv_test", + srcs = [ + "depthwise_conv_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "fully_connected_test", + srcs = [ + "fully_connected_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "softmax_test", + srcs = [ + "softmax_test.cc", + ], + deps = [ + ":all_ops_resolver", + ":test_utils", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + "//tensorflow/contrib/lite/experimental/micro/testing:micro_test", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd0a37badb8ab1e739fdee9c8be9c3f800e80e2e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" + +namespace tflite { +namespace ops { +namespace micro { + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Micro_Register_DEPTHWISE_CONV_2D() { + return Register_DEPTHWISE_CONV_2D(); +} + +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Micro_Register_FULLY_CONNECTED() { + return Register_FULLY_CONNECTED(); +} + +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Micro_Register_SOFTMAX() { return Register_SOFTMAX(); } + +AllOpsResolver::AllOpsResolver() { + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, + Micro_Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Micro_Register_FULLY_CONNECTED(), + /* min_version */ 1, + /* max_version */ 2); + AddBuiltin(BuiltinOperator_SOFTMAX, Micro_Register_SOFTMAX()); +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..f836064a3f63443ff577e7ac7a8b791cbb2c24c5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ + +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace micro { + +class AllOpsResolver : public MicroMutableOpResolver { + public: + AllOpsResolver(); + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace micro +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f17263181982afdaa1941194b88d58f0ef0ca74 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + int out_width, int out_height, + const TfLiteType data_type, OpData* data) { + data->padding.height = ComputePadding(params->stride_height, 1, height, + filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, 1, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); +} + +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int out_width = ComputeOutSize(params->padding, width, filter_width, + params->stride_width); + int out_height = ComputeOutSize(params->padding, height, filter_height, + params->stride_height); + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, out_width, + out_height, data_type, data)); + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, bias, output); + break; + default: + context->ReportError(context, "Type %d not currently supported.", + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {depthwise_conv::Init, depthwise_conv::Free, + depthwise_conv::Prepare, depthwise_conv::Eval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..169899c471dd44399b4d8a479cecbbbd78ba1215 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestDepthwiseConvFloat(std::initializer_list input_dims_data, + std::initializer_list input_data, + std::initializer_list filter_dims_data, + std::initializer_list filter_data, + std::initializer_list bias_dims_data, + std::initializer_list bias_data, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + TfLiteFusedActivation activation, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(filter_data, filter_dims, "filter_tensor"), + CreateFloatTensor(bias_data, bias_dims, "bias_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + int input_depth = input_dims->data[3]; + int output_depth = filter_dims->data[3]; + int depth_mul = output_depth / input_depth; + TfLiteDepthwiseConvParams builtin_data = { + kTfLitePaddingValid, 1, 1, depth_mul, activation, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestDepthwiseConvQuantized( + std::initializer_list input_dims_data, + std::initializer_list input_data, float input_min, float input_max, + std::initializer_list filter_dims_data, + std::initializer_list filter_data, float filter_min, + float filter_max, std::initializer_list bias_dims_data, + std::initializer_list bias_data, float bias_min, float bias_max, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, float output_min, + float output_max, TfLiteFusedActivation activation, uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(filter_data, filter_dims, "filter_tensor", + filter_min, filter_max), + CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min, + bias_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + int input_depth = input_dims->data[3]; + int output_depth = filter_dims->data[3]; + int depth_mul = output_depth / input_depth; + TfLiteDepthwiseConvParams builtin_data = { + kTfLitePaddingValid, 1, 1, depth_mul, activation, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { + const int output_dims_count = 8; + float output_data[output_dims_count]; + tflite::testing::TestDepthwiseConvFloat( // + {4, 1, 3, 2, 2}, // Input shape. + { + 1, 2, 7, 8, // Input values. + 3, 4, 9, 10, // + 5, 6, 11, 12, // + }, + {4, 1, 2, 2, 4}, // Filters shape. + { + 1, 2, 3, 4, // Filters values. + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }, + {1, 4}, // Bias shape. + { + 1, 2, 3, 4, // Bias values. + }, + { + 71, -34, 99, -20, // Expected results. + 91, -26, 127, -4, // + }, + {4, 1, 2, 1, 4}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float filter_min = -63.5f; + const float filter_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 8; + uint8_t output_data[output_dims_count]; + + tflite::testing::TestDepthwiseConvQuantized( // + {4, 1, 3, 2, 2}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), + F2Q(2, input_min, input_max), + F2Q(7, input_min, input_max), + F2Q(8, input_min, input_max), + F2Q(3, input_min, input_max), + F2Q(4, input_min, input_max), + F2Q(9, input_min, input_max), + F2Q(10, input_min, input_max), + F2Q(5, input_min, input_max), + F2Q(6, input_min, input_max), + F2Q(11, input_min, input_max), + F2Q(12, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {4, 1, 2, 2, 4}, // Filter shape. + { + // Filter values. + F2Q(1, filter_min, filter_max), + F2Q(2, filter_min, filter_max), + F2Q(3, filter_min, filter_max), + F2Q(4, filter_min, filter_max), + F2Q(-9, filter_min, filter_max), + F2Q(10, filter_min, filter_max), + F2Q(-11, filter_min, filter_max), + F2Q(12, filter_min, filter_max), + F2Q(5, filter_min, filter_max), + F2Q(6, filter_min, filter_max), + F2Q(7, filter_min, filter_max), + F2Q(8, filter_min, filter_max), + F2Q(13, filter_min, filter_max), + F2Q(-14, filter_min, filter_max), + F2Q(15, filter_min, filter_max), + F2Q(-16, filter_min, filter_max), + }, + filter_min, filter_max, // Filter quantization range. + {1, 4}, // Bias shape. + { + // Bias values. + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + F2Q32(4, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(71, output_min, output_max), + F2Q(-34, output_min, output_max), + F2Q(99, output_min, output_max), + F2Q(-20, output_min, output_max), + F2Q(91, output_min, output_max), + F2Q(-26, output_min, output_max), + F2Q(127, output_min, output_max), + F2Q(-4, output_min, output_max), + }, + {4, 1, 2, 1, 4}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestRelu) { + const int output_dims_count = 8; + float output_data[output_dims_count]; + tflite::testing::TestDepthwiseConvFloat( // + {4, 1, 3, 2, 2}, // Input shape. + { + 1, 2, 7, 8, // Input values. + 3, 4, 9, 10, // + 5, 6, 11, 12, // + }, + {4, 1, 2, 2, 4}, // Filters shape. + { + 1, 2, 3, 4, // Filters values. + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }, + {1, 4}, // Bias shape. + { + 1, 2, 3, 4, // Bias values. + }, + { + 71, 0, 99, 0, // Expected results. + 91, 0, 127, 0, // + }, + {4, 1, 2, 1, 4}, // Output shape. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestReluQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float filter_min = -63.5f; + const float filter_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 8; + uint8_t output_data[output_dims_count]; + + tflite::testing::TestDepthwiseConvQuantized( // + {4, 1, 3, 2, 2}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), + F2Q(2, input_min, input_max), + F2Q(7, input_min, input_max), + F2Q(8, input_min, input_max), + F2Q(3, input_min, input_max), + F2Q(4, input_min, input_max), + F2Q(9, input_min, input_max), + F2Q(10, input_min, input_max), + F2Q(5, input_min, input_max), + F2Q(6, input_min, input_max), + F2Q(11, input_min, input_max), + F2Q(12, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {4, 1, 2, 2, 4}, // Filter shape. + { + // Filter values. + F2Q(1, filter_min, filter_max), + F2Q(2, filter_min, filter_max), + F2Q(3, filter_min, filter_max), + F2Q(4, filter_min, filter_max), + F2Q(-9, filter_min, filter_max), + F2Q(10, filter_min, filter_max), + F2Q(-11, filter_min, filter_max), + F2Q(12, filter_min, filter_max), + F2Q(5, filter_min, filter_max), + F2Q(6, filter_min, filter_max), + F2Q(7, filter_min, filter_max), + F2Q(8, filter_min, filter_max), + F2Q(13, filter_min, filter_max), + F2Q(-14, filter_min, filter_max), + F2Q(15, filter_min, filter_max), + F2Q(-16, filter_min, filter_max), + }, + filter_min, filter_max, // Filter quantization range. + {1, 4}, // Bias shape. + { + // Bias values. + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + F2Q32(4, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(71, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(99, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(91, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(127, output_min, output_max), + F2Q(0, output_min, output_max), + }, + {4, 1, 2, 1, 4}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e9e54cafb8c91af1b42d6d23396495ecad6e602 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace fully_connected { +namespace { + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus CalculateOpData(TfLiteContext* context, + TfLiteFullyConnectedParams* params, + TfLiteType data_type, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpData* data) { + TfLiteStatus status = kTfLiteOk; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + } + return status; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(filter), GetTensorData(filter), \ + GetTensorShape(bias), GetTensorData(bias), \ + GetTensorShape(output), GetTensorData(output), \ + nullptr) + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } + + return kTfLiteOk; +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + tflite::reference_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(filter), GetTensorData(filter), + GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), + GetTensorData(output)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TfLiteType data_type = input->type; + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + + switch (filter->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, bias, + output); + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, filter, bias, + output); + + default: + context->ReportError(context, "Type %d not currently supported.", + filter->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED() { + static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b42bf4c3bca75572dbf8e1907e7fb94be24d41bd --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc @@ -0,0 +1,643 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestFullyConnectedFloat(std::initializer_list input_dims_data, + std::initializer_list input_data, + std::initializer_list weights_dims_data, + std::initializer_list weights_data, + std::initializer_list bias_dims_data, + std::initializer_list bias_data, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + TfLiteFusedActivation activation, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(weights_data, weights_dims, "weights_tensor"), + CreateFloatTensor(bias_data, bias_dims, "bias_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteFullyConnectedParams builtin_data = { + activation, + kTfLiteFullyConnectedWeightsFormatDefault, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestFullyConnectedQuantized( + std::initializer_list input_dims_data, + std::initializer_list input_data, float input_min, float input_max, + std::initializer_list weights_dims_data, + std::initializer_list weights_data, float weights_min, + float weights_max, std::initializer_list bias_dims_data, + std::initializer_list bias_data, float bias_min, float bias_max, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, float output_min, + float output_max, TfLiteFusedActivation activation, uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(weights_data, weights_dims, "weights_tensor", + weights_min, weights_max), + CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min, + bias_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteFullyConnectedParams builtin_data = { + activation, + kTfLiteFullyConnectedWeightsFormatDefault, + }; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 10}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, 2, 3, // Bias values. + }, + { + 24, 25, 26, 58, 59, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest2) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 2}, // Input shape. + { + 1, 2, // b = 0 + 2, 1, // b = 1 + }, + {2, 1, 2}, // Weights shape. + { + 2, 4, // u = 0 + }, + {1, 1}, // Bias shape. + { + 1, // Bias values. + }, + { + 11, 9, // Expected results. + }, + {2, 2, 1}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestRelu) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {2, 2, 10}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, -2, 3, // Bias values. + }, + { + 24, 0, 26, 58, 0, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantizedRelu) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(-1, weights_min, weights_max), F2Q(-2, weights_min, weights_max), + F2Q(-3, weights_min, weights_max), F2Q(-4, weights_min, weights_max), + F2Q(-5, weights_min, weights_max), F2Q(-6, weights_min, weights_max), + F2Q(-7, weights_min, weights_max), F2Q(-8, weights_min, weights_max), + F2Q(-9, weights_min, weights_max), F2Q(-10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(0, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(0, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActRelu, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantizedOutputMultiplierGreaterThan1) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -127.0f; + const float input_max = 128.0f; + const float weights_min = -127.0f; + const float weights_max = 128.0f; + const float bias_min = 0.0f; + const float bias_max = 256.0f * (1 << 24); + const float output_min = -63.5f; + const float output_max = 64.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {2, 2, 10}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInput) { + const int output_dims_count = 6; + float output_data[output_dims_count]; + tflite::testing::TestFullyConnectedFloat( // + {4, 1, 1, 5, 1}, // Input shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }, + {2, 3, 10}, // Weights shape. + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }, + {1, 3}, // Bias shape. + { + 1, 2, 3, // Bias values. + }, + { + 24, 25, 26, 58, 59, 60, // Expected results. + }, + {2, 2, 3}, // Output shape. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInputQuantized) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float weights_min = -63.5f; + const float weights_max = 64.0f; + const float bias_min = 0.0f; + const float bias_max = 64.0f * (1 << 24); + const float output_min = -127.0f; + const float output_max = 128.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {4, 1, 1, 5, 1}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedOutputMultiplierGreaterThan1) { + using tflite::testing::F2Q; + using tflite::testing::F2Q32; + + const float input_min = -127.0f; + const float input_max = 128.0f; + const float weights_min = -127.0f; + const float weights_max = 128.0f; + const float bias_min = 0.0f; + const float bias_max = 256.0f * (1 << 24); + const float output_min = -63.5f; + const float output_max = 64.0f; + const int output_dims_count = 6; + uint8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized( // + {4, 1, 1, 5, 1}, // Input shape. + { + // Input values. + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(8, input_min, input_max), + F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max), + F2Q(1, input_min, input_max), F2Q(2, input_min, input_max), + F2Q(3, input_min, input_max), F2Q(4, input_min, input_max), + F2Q(5, input_min, input_max), F2Q(6, input_min, input_max), + F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max), + F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max), + }, + input_min, input_max, // Input quantization range. + {2, 3, 10}, // Weights shape. + { + // Weight values. + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max), + F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max), + F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max), + F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max), + F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max), + }, + weights_min, weights_max, // Weights quantization range. + {1, 3}, // Bias shape. + { + F2Q32(1, bias_min, bias_max), + F2Q32(2, bias_min, bias_max), + F2Q32(3, bias_min, bias_max), + }, + bias_min, bias_max, // Bias quantization range. + { + // Expected results. + F2Q(24, output_min, output_max), + F2Q(25, output_min, output_max), + F2Q(26, output_min, output_max), + F2Q(58, output_min, output_max), + F2Q(59, output_min, output_max), + F2Q(60, output_min, output_max), + }, + {2, 2, 3}, // Output shape. + output_min, output_max, // Output quantization range. + kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4019a067c563cac25d9918e4bdf75913bdfa3d6 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc @@ -0,0 +1,213 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace activations { +namespace { + +struct OpData { + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; + int diff_min = 0; +}; + +TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, + const TfLiteTensor* input, + TfLiteTensor* output, + const TfLiteSoftmaxParams* params, + OpData* data) { + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + return kTfLiteOk; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Free(TfLiteContext* context, void* buffer) {} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +// Takes a 1D tensor and performs softmax along it. +void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int input_size = input->dims->data[0]; + tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta, + output->data.f); +} + +// Takes a 2D tensor and perform softmax along the last dimension. +void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + tflite::reference_ops::Softmax(input->data.f, input_size, batch_size, + params->beta, output->data.f); +} + +void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 1D + // tensor is 4D in a special way. We will convert a (Y) shape into a (1, + // 1, 1, Y) shape. + const int input_size = input->dims->data[0]; + const int32_t shape_data[4] = {1, 1, 1, input_size}; + RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); +} + +void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 2D + // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, + // 1, 1, Y) shape. + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int32_t shape_data[4] = {batch_size, 1, 1, input_size}; + RuntimeShape shape(4, shape_data); + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax(op_params, shape, + GetTensorData(input), shape, + GetTensorData(output)); +} + +// Takes a 4D tensor and perform softmax along the forth dimension. +void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + SoftmaxParams op_params; + op_params.beta = params->beta; + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); +} + +void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + SoftmaxParams op_params; + op_params.input_multiplier = data->input_multiplier; + op_params.input_left_shift = data->input_left_shift; + op_params.diff_min = data->diff_min; + tflite::reference_ops::Softmax( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + OpData local_data_object; + OpData* data = &local_data_object; + TF_LITE_ENSURE_STATUS( + CalculateSoftmaxOpData(context, input, output, params, data)); + + // TODO(ahentz): consider an implementation that works for many (all?) + // dimensions. + switch (input->type) { + case kTfLiteFloat32: { + if (NumDimensions(input) == 1) { + Softmax1DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 2) { + Softmax2DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DFloat(input, output, params); + return kTfLiteOk; + } + context->ReportError( + context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); + return kTfLiteError; + } + case kTfLiteUInt8: { + if (NumDimensions(input) == 1) { + Softmax1DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 2) { + Softmax2DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DQuantized(input, output, params, data); + return kTfLiteOk; + } + context->ReportError( + context, "Only 2D and 4D tensors supported currently, got %dD.", + NumDimensions(input)); + return kTfLiteError; + } + default: + context->ReportError( + context, "Only float32 and uint8_t supported currently, got %d.", + input->type); + return kTfLiteError; + } +} +} // namespace activations + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..694456d8ace5182578f9b59c2de8bbad0447b4ee --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestSoftmaxFloat(std::initializer_list input_dims_data, + std::initializer_list input_data, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 2; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims, "input_tensor"), + CreateFloatTensor(output_data, output_dims, "output_tensor"), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteSoftmaxParams builtin_data = {1.0f}; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i], + 1e-5f); + } +} + +void TestSoftmaxQuantized(std::initializer_list input_dims_data, + std::initializer_list input_data, + float input_min, float input_max, + std::initializer_list expected_output_data, + std::initializer_list output_dims_data, + float output_min, float output_max, + uint8_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteSoftmaxParams builtin_data = {1.0f}; + const char* init_data = reinterpret_cast(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SimpleTest) { + const int output_dims_count = 10; + float output_data[output_dims_count]; + tflite::testing::TestSoftmaxFloat( // + {2, 2, 5}, // Input shape. + { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }, + { + // Expected results. + 0.011656231, + 0.031684921, + 0.086128544, + 0.234121657, + 0.636408647, + 0.636408647, + 0.234121657, + 0.086128544, + 0.031684921, + 0.011656231, + }, + {2, 2, 5}, // Output shape. + output_data); +} + +TF_LITE_MICRO_TEST(SimpleTestQuantized) { + using tflite::testing::F2Q; + + const float input_min = -63.5f; + const float input_max = 64.0f; + const float output_min = 0.0f; + const float output_max = (255.0f / 256.0f); + const int output_dims_count = 5; + uint8_t output_data[output_dims_count]; + tflite::testing::TestSoftmaxQuantized( // + {2, 1, 5}, // Input shape. + { + F2Q(1.0, input_min, input_max), + F2Q(2.0, input_min, input_max), + F2Q(3.0, input_min, input_max), + F2Q(4.0, input_min, input_max), + F2Q(5.0, input_min, input_max), + }, + input_min, input_max, // Input quantized range. + { + // Expected results. + F2Q(0.011656231, output_min, output_max), + F2Q(0.031684921, output_min, output_max), + F2Q(0.086128544, output_min, output_max), + F2Q(0.234121657, output_min, output_max), + F2Q(0.636408647, output_min, output_max), + }, + {2, 1, 5}, // Output shape. + output_min, output_max, // Output quantized range. + output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..789a48ece8bd68544649fb05548355cb796ccabb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h @@ -0,0 +1,170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h" +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { + +// How many elements are in the array with this shape. +inline int ElementCount(const TfLiteIntArray& dims) { + int result = 1; + for (int i = 0; i < dims.size; ++i) { + result *= dims.data[i]; + } + return result; +} + +// Wrapper to forward kernel errors to the interpreter's error reporter. +inline void ReportOpError(struct TfLiteContext* context, const char* format, + ...) { + ErrorReporter* error_reporter = static_cast(context->impl_); + va_list args; + va_start(args, format); + error_reporter->Report(format, args); + va_end(args); +} + +// Derives the quantization scaling factor from a min and max range. +template +inline float ScaleFromMinMax(const float min, const float max) { + return (max - min) / ((std::numeric_limits::max() * 1.0) - + std::numeric_limits::min()); +} + +// Derives the quantization zero point from a min and max range. +template +inline int ZeroPointFromMinMax(const float min, const float max) { + return static_cast((-min / ScaleFromMinMax(min, max)) + 0.5f); +} + +// Converts a float value into an unsigned eight-bit quantized value. +inline uint8_t F2Q(const float value, const float min, const float max) { + int32_t result = ZeroPointFromMinMax(min, max) + + (value / ScaleFromMinMax(min, max)) + 0.5f; + if (result < 0) { + result = 0; + } + if (result > 256) { + result = 256; + } + return result; +} + +// Converts a float value into a signed thirty-two-bit quantized value. +inline uint8_t F2Q32(const float value, const float min, const float max) { + return static_cast((value - ZeroPointFromMinMax(min, max)) / + ScaleFromMinMax(min, max)); +} + +inline void PopulateContext(TfLiteTensor* tensors, int tensors_size, + TfLiteContext* context) { + context->tensors_size = tensors_size; + context->tensors = tensors; + context->impl_ = static_cast(micro_test::reporter); + context->GetExecutionPlan = nullptr; + context->ResizeTensor = nullptr; + context->ReportError = ReportOpError; + context->AddTensors = nullptr; + context->GetNodeAndRegistration = nullptr; + context->ReplaceSubgraphsWithDelegateKernels = nullptr; + context->recommended_num_threads = 1; + context->GetExternalContext = nullptr; + context->SetExternalContext = nullptr; +} + +inline TfLiteIntArray* IntArrayFromInts(const int* int_array) { + return const_cast( + reinterpret_cast(int_array)); +} + +inline TfLiteIntArray* IntArrayFromInitializer( + std::initializer_list int_initializer) { + return IntArrayFromInts(int_initializer.begin()); +} + +inline TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims, + const char* name) { + const size_t bytes = ElementCount(*dims) * sizeof(float); + return { + kTfLiteFloat32, {const_cast(reinterpret_cast(data))}, + dims, {}, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateFloatTensor(std::initializer_list data, + TfLiteIntArray* dims, const char* name) { + return CreateFloatTensor(data.begin(), dims, name); +} + +inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + const size_t bytes = ElementCount(*dims) * sizeof(uint8_t); + const TfLiteQuantizationParams q_params = { + ScaleFromMinMax(min, max), + ZeroPointFromMinMax(min, max)}; + return { + kTfLiteUInt8, {const_cast(reinterpret_cast(data))}, + dims, q_params, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateQuantizedTensor(std::initializer_list data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + return CreateQuantizedTensor(data.begin(), dims, name, min, max); +} + +inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + const size_t bytes = ElementCount(*dims) * sizeof(int32_t); + const TfLiteQuantizationParams q_params = { + ScaleFromMinMax(min, max), + ZeroPointFromMinMax(min, max)}; + return { + kTfLiteUInt8, {const_cast(reinterpret_cast(data))}, + dims, q_params, + kTfLiteMemNone, bytes, + nullptr, name}; +} + +inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list data, + TfLiteIntArray* dims, + const char* name, float min, + float max) { + return CreateQuantized32Tensor(data.begin(), dims, name, min, max); +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..99dd8836611c287b7f76104c29c12a73d219ccb3 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" + +#ifdef TF_LITE_MCU_DEBUG_LOG +#include +#else // TF_LITE_MCU_DEBUG_LOG +#include +#include +void DebugLog(const char* s) { fprintf(stderr, "%s", s); } +void DebugLogInt32(int32_t i) { fprintf(stderr, "%d", i); } +void DebugLogUInt32(uint32_t i) { fprintf(stderr, "%d", i); } +void DebugLogHex(uint32_t i) { fprintf(stderr, "0x%8x", i); } +void DebugLogFloat(float i) { fprintf(stderr, "%f", i); } +#endif // TF_LITE_MCU_DEBUG_LOG + +namespace tflite { +namespace { +void DebugLogPrintf(const char* format, va_list args) { + const int output_cache_size = 64; + char output_cache[output_cache_size + 1]; + int output_cache_index = 0; + const char* current = format; + while (*current != 0) { + if (*current == '%') { + const char next = *(current + 1); + if ((next == 'd') || (next == 's')) { + current += 1; + if (output_cache_index > 0) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + if (next == 'd') { + DebugLogInt32(va_arg(args, int)); + } else if (next == 's') { + DebugLog(va_arg(args, char*)); + } + } + } else { + output_cache[output_cache_index] = *current; + output_cache_index += 1; + } + if (output_cache_index >= output_cache_size) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + current += 1; + } + if (output_cache_index > 0) { + output_cache[output_cache_index] = 0; + DebugLog(output_cache); + output_cache_index = 0; + } + DebugLog("\n"); +} +} // namespace + +int MicroErrorReporter::Report(const char* format, va_list args) { + DebugLogPrintf(format, args); + return 0; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..33e54f7990af6cff4f8706d2889c335087581af4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ + +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" + +namespace tflite { + +class MicroErrorReporter : public ErrorReporter { + public: + ~MicroErrorReporter() {} + int Report(const char* format, va_list args) override; + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc similarity index 53% rename from tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h rename to tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc index 86250e6692004a12a1fa338767a5db1e4c2e4195..ef3c32050c0e826c005f185553974170da7e486a 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" -#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" - -#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +int main(int argc, char** argv) { + tflite::MicroErrorReporter micro_error_reporter; + tflite::ErrorReporter* error_reporter = µ_error_reporter; + error_reporter->Report("Number: %d", 42); + error_reporter->Report("Badly-formed format string %"); + error_reporter->Report("Another % badly-formed %% format string"); + error_reporter->Report("~~~%s~~~", "ALL TESTS PASSED"); +} diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f38991bb0ef3d0134b4d9a1eb6e148a140fe6f9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc @@ -0,0 +1,310 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" + +namespace tflite { +namespace { +const int kStackDataAllocatorSize = 128; +class StackDataAllocator : public BuiltinDataAllocator { + public: + void* Allocate(size_t size) override { + if (size > kStackDataAllocatorSize) { + return nullptr; + } else { + return data_; + } + } + void Deallocate(void* data) override { + // Do nothing. + } + + private: + uint8_t data_[kStackDataAllocatorSize]; + + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +const char* OpNameFromRegistration(const TfLiteRegistration* registration) { + if (registration->builtin_code == BuiltinOperator_CUSTOM) { + return registration->custom_name; + } else { + return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code)); + } +} + +void ReportOpError(struct TfLiteContext* context, const char* format, ...) { + MicroInterpreter* interpreter = + static_cast(context->impl_); + va_list args; + va_start(args, format); + interpreter->error_reporter()->Report(format, args); + va_end(args); +} + +} // namespace + +MicroInterpreter::MicroInterpreter(const Model* model, + const OpResolver& op_resolver, + SimpleTensorAllocator* tensor_allocator, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + tensor_allocator_(tensor_allocator), + error_reporter_(error_reporter), + initialization_status_(kTfLiteOk) { + const flatbuffers::Vector>* buffers = + model->buffers(); + auto* subgraphs = model->subgraphs(); + if (subgraphs->size() != 1) { + error_reporter->Report("Only 1 subgraph is currently supported.\n"); + initialization_status_ = kTfLiteError; + return; + } + subgraph_ = (*subgraphs)[0]; + tensors_ = subgraph_->tensors(); + operators_ = subgraph_->operators(); + + context_.tensors_size = tensors_->Length(); + context_.tensors = + reinterpret_cast(tensor_allocator_->AllocateMemory( + sizeof(TfLiteTensor) * context_.tensors_size)); + for (int i = 0; i < subgraph_->inputs()->Length(); ++i) { + const int tensor_index = subgraph_->inputs()->Get(i); + const auto* tensor = tensors_->Get(tensor_index); + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, 0, operators_->Length(), buffers, error_reporter, + &context_.tensors[tensor_index]); + if (initialization_status_ != kTfLiteOk) { + return; + } + } + + int* first_created = reinterpret_cast( + tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length())); + int* last_used = reinterpret_cast( + tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length())); + for (int i = 0; i < tensors_->Length(); ++i) { + first_created[i] = -1; + last_used[i] = -1; + } + + for (int i = (operators_->Length() - 1); i >= 0; --i) { + const auto* op = operators_->Get(i); + for (int n = 0; n < op->inputs()->Length(); ++n) { + const int tensor_index = op->inputs()->Get(n); + if ((last_used[tensor_index] == -1) || (last_used[tensor_index] < i)) { + last_used[tensor_index] = i; + } + } + for (int n = 0; n < op->outputs()->Length(); ++n) { + const int tensor_index = op->outputs()->Get(n); + const int create_before = i; + int destroy_after = last_used[tensor_index]; + if (destroy_after == -1) { + destroy_after = operators_->Length(); + } + const auto* tensor = tensors_->Get(tensor_index); + if (!tensor->is_variable()) { + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, create_before, destroy_after, buffers, error_reporter, + &context_.tensors[tensor_index]); + if (initialization_status_ != kTfLiteOk) { + return; + } + first_created[tensor_index] = i; + } + } + } + + for (int i = 0; i < tensors_->Length(); ++i) { + const auto* tensor = tensors_->Get(i); + const bool is_read_only = (first_created[i] == -1) && (last_used[i] != -1); + if (tensor->is_variable() || is_read_only) { + initialization_status_ = tensor_allocator_->AllocateTensor( + *tensor, 0, operators_->Length(), buffers, error_reporter, + &context_.tensors[i]); + if (initialization_status_ != kTfLiteOk) { + return; + } + } + } + context_.impl_ = static_cast(this); + context_.GetExecutionPlan = nullptr; + context_.ResizeTensor = nullptr; + context_.ReportError = ReportOpError; + context_.AddTensors = nullptr; + context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.recommended_num_threads = 1; + context_.GetExternalContext = nullptr; + context_.SetExternalContext = nullptr; +} + +TfLiteStatus MicroInterpreter::Invoke() { + if (initialization_status_ != kTfLiteOk) { + error_reporter_->Report("Invoke() called after initialization failed\n"); + return kTfLiteError; + } + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (int i = 0; i < operators_->Length(); ++i) { + const auto* op = operators_->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= opcodes->size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + return kTfLiteError; + } + auto opcode = (*opcodes)[index]; + const TfLiteRegistration* registration = nullptr; + status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, + ®istration); + if (status != kTfLiteOk) { + return status; + } + if (registration == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + return kTfLiteError; + } + BuiltinOperator op_type = + static_cast(registration->builtin_code); + + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + StackDataAllocator stack_data_allocator; + const char* custom_data = nullptr; + size_t custom_data_size = 0; + unsigned char* builtin_data = nullptr; + if (op->custom_options()) { + custom_data = reinterpret_cast(op->custom_options()->data()); + custom_data_size = op->custom_options()->size(); + } else { + TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, + &stack_data_allocator, + (void**)(&builtin_data))); + } + + const char* init_data; + size_t init_data_size; + if (registration->builtin_code == BuiltinOperator_CUSTOM) { + init_data = custom_data; + init_data_size = custom_data_size; + } else { + init_data = reinterpret_cast(builtin_data); + init_data_size = 0; + } + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context_, init_data, init_data_size); + } + + const int kMaxInputs = 16; + int inputs_data[kMaxInputs + 1]; + TfLiteIntArray* inputs_array = + reinterpret_cast(inputs_data); + if (op->inputs()->Length() >= kMaxInputs) { + error_reporter_->Report("Too many inputs (%d)\n", op->inputs()->Length()); + return kTfLiteError; + } + inputs_array->size = op->inputs()->Length(); + for (int n = 0; n < op->inputs()->Length(); ++n) { + inputs_array->data[n] = op->inputs()->Get(n); + } + + const int kMaxOutputs = 16; + int outputs_data[kMaxOutputs + 1]; + TfLiteIntArray* outputs_array = + reinterpret_cast(outputs_data); + if (op->outputs()->Length() >= kMaxOutputs) { + error_reporter_->Report("Too many outputs (%d)\n", + op->outputs()->Length()); + return kTfLiteError; + } + outputs_array->size = op->outputs()->Length(); + for (int n = 0; n < op->outputs()->Length(); ++n) { + outputs_array->data[n] = op->outputs()->Get(n); + } + + const int kMaxTemporaries = 16; + int temporaries_data[kMaxTemporaries + 1]; + TfLiteIntArray* temporaries_array = + reinterpret_cast(temporaries_data); + temporaries_array->size = 0; + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast(builtin_data); + node.custom_initial_data = custom_data; + node.custom_initial_data_size = custom_data_size; + node.delegate = nullptr; + if (registration->prepare) { + TfLiteStatus prepare_status = registration->prepare(&context_, &node); + if (prepare_status != kTfLiteOk) { + error_reporter_->Report( + "Node %s (number %d) failed to prepare with status %d", + OpNameFromRegistration(registration), i, prepare_status); + return kTfLiteError; + } + } + + if (registration->invoke) { + TfLiteStatus invoke_status = registration->invoke(&context_, &node); + if (invoke_status != kTfLiteOk) { + error_reporter_->Report( + "Node %s (number %d) failed to invoke with status %d", + OpNameFromRegistration(registration), i, invoke_status); + return kTfLiteError; + } + } + + if (registration->free) { + registration->free(&context_, user_data); + } + } + return status; +} + +TfLiteTensor* MicroInterpreter::input(int index) { + const flatbuffers::Vector* inputs = subgraph_->inputs(); + const size_t length = inputs->Length(); + if ((index < 0) || (index >= length)) { + error_reporter_->Report("Input index %d out of range (length is %d)", index, + length); + return nullptr; + } + return &(context_.tensors[inputs->Get(index)]); +} + +TfLiteTensor* MicroInterpreter::output(int index) { + const flatbuffers::Vector* outputs = subgraph_->outputs(); + const size_t length = outputs->Length(); + if ((index < 0) || (index >= outputs->Length())) { + error_reporter_->Report("Output index %d out of range (length is %d)", + index, length); + return nullptr; + } + return &(context_.tensors[outputs->Get(index)]); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..a88514cde849595244d36a31900e6d1c2ae1714b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +class MicroInterpreter { + public: + // The lifetime of the model, op resolver, allocator, and error reporter must + // be at least as long as that of the interpreter object, since the + // interpreter may need to access them at any time. This means that you should + // usually create them with the same scope as each other, for example having + // them all allocated on the stack as local variables through a top-level + // function. + // The interpreter doesn't do any deallocation of any of the pointed-to + // objects, ownership remains with the caller. + MicroInterpreter(const Model* model, const OpResolver& op_resolver, + SimpleTensorAllocator* tensor_allocator, + ErrorReporter* error_reporter); + + TfLiteStatus Invoke(); + + size_t tensors_size() const { return context_.tensors_size; } + TfLiteTensor* tensor(int tensor_index); + + TfLiteTensor* input(int index); + size_t inputs_size() const { return subgraph_->inputs()->Length(); } + + TfLiteTensor* output(int index); + size_t outputs_size() const { return subgraph_->outputs()->Length(); } + + TfLiteStatus initialization_status() const { return initialization_status_; } + + ErrorReporter* error_reporter() { return error_reporter_; } + + private: + const Model* model_; + const OpResolver& op_resolver_; + SimpleTensorAllocator* tensor_allocator_; + ErrorReporter* error_reporter_; + + TfLiteStatus initialization_status_; + const flatbuffers::Vector>* tensors_; + const flatbuffers::Vector>* operators_; + TfLiteContext context_; + + const SubGraph* subgraph_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..251e5f72037717f74bc3472b69144cff299f0668 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + const int32_t* input_data = input->data.i32; + const TfLiteTensor* weight = &context->tensors[node->inputs->data[1]]; + const uint8_t* weight_data = weight->data.uint8; + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + int32_t* output_data = output->data.i32; + output_data[0] = input_data[0] + weight_data[0]; + return kTfLiteOk; +} + +class MockOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(BuiltinOperator op, + int version) const override { + return nullptr; + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (strcmp(op, "mock_custom") == 0) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } +}; + +class StackAllocator : public flatbuffers::Allocator { + public: + StackAllocator() : data_(data_backing_), data_size_(0) {} + + uint8_t* allocate(size_t size) override { + if ((data_size_ + size) > kStackAllocatorSize) { + // TODO(petewarden): Add error reporting beyond returning null! + return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; + } + + void deallocate(uint8_t* p, size_t) override {} + + static StackAllocator& instance() { + // Avoid using true dynamic memory allocation to be portable to bare metal. + static char inst_memory[sizeof(StackAllocator)]; + static StackAllocator* inst = new (inst_memory) StackAllocator; + return *inst; + } + + static constexpr int kStackAllocatorSize = 4096; + + private: + uint8_t data_backing_[kStackAllocatorSize]; + uint8_t* data_; + int data_size_; +}; + +const Model* BuildMockModel() { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder builder(StackAllocator::kStackAllocatorSize, + &StackAllocator::instance()); + constexpr size_t buffer_data_size = 1; + const uint8_t buffer_data[buffer_data_size] = {21}; + constexpr size_t buffers_size = 2; + const Offset buffers[buffers_size] = { + CreateBuffer(builder), + CreateBuffer(builder, + builder.CreateVector(buffer_data, buffer_data_size))}; + constexpr size_t tensor_shape_size = 1; + const int32_t tensor_shape[tensor_shape_size] = {1}; + constexpr size_t tensors_size = 3; + const Offset tensors[tensors_size] = { + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder.CreateString("test_input_tensor"), 0, false), + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_UINT8, 1, + builder.CreateString("test_weight_tensor"), 0, false), + CreateTensor(builder, + builder.CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, + builder.CreateString("test_output_tensor"), 0, false), + }; + constexpr size_t inputs_size = 1; + const int32_t inputs[inputs_size] = {0}; + constexpr size_t outputs_size = 1; + const int32_t outputs[outputs_size] = {2}; + constexpr size_t operator_inputs_size = 2; + const int32_t operator_inputs[operator_inputs_size] = {0, 1}; + constexpr size_t operator_outputs_size = 1; + const int32_t operator_outputs[operator_outputs_size] = {2}; + constexpr size_t operators_size = 1; + const Offset operators[operators_size] = {CreateOperator( + builder, 0, builder.CreateVector(operator_inputs, operator_inputs_size), + builder.CreateVector(operator_outputs, operator_outputs_size), + BuiltinOptions_NONE)}; + constexpr size_t subgraphs_size = 1; + const Offset subgraphs[subgraphs_size] = { + CreateSubGraph(builder, builder.CreateVector(tensors, tensors_size), + builder.CreateVector(inputs, inputs_size), + builder.CreateVector(outputs, outputs_size), + builder.CreateVector(operators, operators_size), + builder.CreateString("test_subgraph"))}; + constexpr size_t operator_codes_size = 1; + const Offset operator_codes[operator_codes_size] = { + CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "mock_custom", + 0)}; + const Offset model_offset = CreateModel( + builder, 0, builder.CreateVector(operator_codes, operator_codes_size), + builder.CreateVector(subgraphs, subgraphs_size), + builder.CreateString("test_model"), + builder.CreateVector(buffers, buffers_size)); + FinishModelBuffer(builder, model_offset); + void* model_pointer = builder.GetBufferPointer(); + const Model* model = flatbuffers::GetRoot(model_pointer); + return model; +} + +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestInterpreter) { + const tflite::Model* model = tflite::BuildMockModel(); + TF_LITE_MICRO_EXPECT_NE(nullptr, model); + tflite::MockOpResolver mock_resolver; + constexpr size_t allocator_buffer_size = 1024; + uint8_t allocator_buffer[allocator_buffer_size]; + tflite::SimpleTensorAllocator simple_tensor_allocator(allocator_buffer, + allocator_buffer_size); + tflite::MicroInterpreter interpreter( + model, mock_resolver, &simple_tensor_allocator, micro_test::reporter); + TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size()); + TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size()); + + TfLiteTensor* input = interpreter.input(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, input); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, input->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32); + input->data.i32[0] = 21; + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke()); + + TfLiteTensor* output = interpreter.output(0); + TF_LITE_MICRO_EXPECT_NE(nullptr, output); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size); + TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(4, output->bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32); + TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..40c21c6448c39f27c12e95ae36038510cb346362 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +namespace tflite { + +const TfLiteRegistration* MicroMutableOpResolver::FindOp( + tflite::BuiltinOperator op, int version) const { + for (int i = 0; i < registrations_len_; ++i) { + const TfLiteRegistration& registration = registrations_[i]; + if ((registration.builtin_code == op) && + (registration.version == version)) { + return ®istration; + } + } + return nullptr; +} + +const TfLiteRegistration* MicroMutableOpResolver::FindOp(const char* op, + int version) const { + for (int i = 0; i < registrations_len_; ++i) { + const TfLiteRegistration& registration = registrations_[i]; + if ((registration.builtin_code == -1) && + (strcmp(registration.custom_name, op) == 0) && + (registration.version == version)) { + return ®istration; + } + } + return nullptr; +} + +void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) { + // TODO(petewarden) - Add error reporting hooks so we can report this! + return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; + + *new_registration = *registration; + new_registration->builtin_code = op; + new_registration->version = version; + } +} + +void MicroMutableOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) { + // TODO(petewarden) - Add error reporting hooks so we can report this! + return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; + + *new_registration = *registration; + new_registration->builtin_code = -1; + new_registration->custom_name = name; + new_registration->version = version; + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..f3750a248416cc7244e0dea82be167562fd59ee7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/experimental/micro/compatibility.h" + +#ifndef TFLITE_REGISTRATIONS_MAX +#define TFLITE_REGISTRATIONS_MAX (128) +#endif + +namespace tflite { + +class MicroMutableOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + + private: + TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX]; + int registrations_len_ = 0; + + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5420a33e8778d93d5aad2150438fdba80df372b8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestOperations) { + using tflite::BuiltinOperator_CONV_2D; + using tflite::BuiltinOperator_RELU; + using tflite::MicroMutableOpResolver; + using tflite::OpResolver; + + static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, + tflite::MockPrepare, tflite::MockInvoke}; + + MicroMutableOpResolver micro_mutable_op_resolver; + micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2); + micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3); + OpResolver* resolver = µ_mutable_op_resolver; + + const TfLiteRegistration* registration = + resolver->FindOp(BuiltinOperator_CONV_2D, 0); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp(BuiltinOperator_RELU, 0); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("mock_custom", 0); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp("mock_custom", 10); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("nonexistent_custom", 0); + TF_LITE_MICRO_EXPECT_EQ(nullptr, registration); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c090a20a5fb9e6cb330a40c86236c549c28539e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h" + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +namespace tflite { +namespace { + +TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size, + ErrorReporter* reporter) { + switch (type) { + case kTfLiteFloat32: + *size = sizeof(float); + break; + case kTfLiteInt16: + *size = sizeof(int16_t); + break; + case kTfLiteInt32: + *size = sizeof(int32_t); + break; + case kTfLiteUInt8: + *size = sizeof(uint8_t); + break; + case kTfLiteInt64: + *size = sizeof(int64_t); + break; + case kTfLiteBool: + *size = sizeof(bool); + break; + case kTfLiteComplex64: + *size = sizeof(float) * 2; + break; + default: + reporter->Report( + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus BytesRequired(const tflite::Tensor& flatbuffer_tensor, + size_t dims_size, size_t* bytes, + ErrorReporter* error_reporter) { + TfLiteType tf_lite_type; + TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(), + &tf_lite_type, error_reporter)); + size_t type_size; + TF_LITE_ENSURE_STATUS( + TfLiteTypeSizeOf(tf_lite_type, &type_size, error_reporter)); + *bytes = dims_size * type_size; + return kTfLiteOk; +} + +} // namespace + +TfLiteStatus SimpleTensorAllocator::AllocateTensor( + const tflite::Tensor& flatbuffer_tensor, int create_before, + int destroy_after, + const flatbuffers::Vector>* buffers, + ErrorReporter* error_reporter, TfLiteTensor* result) { + TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(), + &result->type, error_reporter)); + result->is_variable = flatbuffer_tensor.is_variable(); + + result->data.raw = nullptr; + result->bytes = 0; + if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) { + if (auto* array = buffer->data()) { + if (size_t array_size = array->size()) { + result->data.raw = + const_cast(reinterpret_cast(array->data())); + TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, array_size, + &result->bytes, error_reporter)); + } + } + } + if (result->data.raw) { + result->allocation_type = kTfLiteMmapRo; + } else { + int data_size = 1; + for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) { + data_size *= flatbuffer_tensor.shape()->Get(n); + } + TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, data_size, + &result->bytes, error_reporter)); + result->data.raw = reinterpret_cast(AllocateMemory(result->bytes)); + if (result->data.raw == nullptr) { + const char* tensor_name = flatbuffer_tensor.name()->c_str(); + if (tensor_name == nullptr) { + tensor_name = ""; + } + error_reporter->Report( + "Couldn't allocate memory for tensor '%s', wanted %d bytes but only " + "%d were available", + tensor_name, result->bytes, (data_size_max_ - data_size_)); + return kTfLiteError; + } + result->allocation_type = kTfLiteArenaRw; + } + result->dims = reinterpret_cast( + AllocateMemory(sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1))); + result->dims->size = flatbuffer_tensor.shape()->Length(); + for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) { + result->dims->data[n] = flatbuffer_tensor.shape()->Get(n); + } + if (flatbuffer_tensor.quantization()) { + result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0); + result->params.zero_point = + flatbuffer_tensor.quantization()->zero_point()->Get(0); + } + result->allocation = nullptr; + if (flatbuffer_tensor.name()) { + result->name = flatbuffer_tensor.name()->c_str(); + } else { + result->name = ""; + } + result->delegate = nullptr; + result->buffer_handle = 0; + result->data_is_stale = false; + return kTfLiteOk; +} + +uint8_t* SimpleTensorAllocator::AllocateMemory(size_t size) { + if ((data_size_ + size) > data_size_max_) { + // TODO(petewarden): Add error reporting beyond returning null! + return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..4f16a9d0e54cba6fb3b635ceeb39ab10ff59ae73 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// TODO(petewarden): This allocator never frees up or reuses any memory, even +// though we have enough information about lifetimes of the tensors to do so. +// This makes it pretty wasteful, so we should use a more intelligent method. +class SimpleTensorAllocator { + public: + SimpleTensorAllocator(uint8_t* buffer, int buffer_size) + : data_size_(0), data_size_max_(buffer_size), data_(buffer) {} + + TfLiteStatus AllocateTensor( + const tflite::Tensor& flatbuffer_tensor, int create_before, + int destroy_after, + const flatbuffers::Vector>* buffers, + ErrorReporter* error_reporter, TfLiteTensor* result); + + uint8_t* AllocateMemory(size_t size); + + int GetDataSize() const { return data_size_; } + + private: + int data_size_; + int data_size_max_; + uint8_t* data_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c83542724395328cb6a5e038b64dba4b9f4f655b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h" + +#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h" + +namespace tflite { +namespace { +class StackAllocator : public flatbuffers::Allocator { + public: + StackAllocator() : data_(data_backing_), data_size_(0) {} + + uint8_t* allocate(size_t size) override { + if ((data_size_ + size) > kStackAllocatorSize) { + // TODO(petewarden): Add error reporting beyond returning null! + return nullptr; + } + uint8_t* result = data_; + data_ += size; + data_size_ += size; + return result; + } + + void deallocate(uint8_t* p, size_t) override {} + + static StackAllocator& instance() { + // Avoid using true dynamic memory allocation to be portable to bare metal. + static char inst_memory[sizeof(StackAllocator)]; + static StackAllocator* inst = new (inst_memory) StackAllocator; + return *inst; + } + + static constexpr int kStackAllocatorSize = 4096; + + private: + uint8_t data_backing_[kStackAllocatorSize]; + uint8_t* data_; + int data_size_; +}; + +flatbuffers::FlatBufferBuilder* BuilderInstance() { + static char inst_memory[sizeof(flatbuffers::FlatBufferBuilder)]; + static flatbuffers::FlatBufferBuilder* inst = + new (inst_memory) flatbuffers::FlatBufferBuilder( + StackAllocator::kStackAllocatorSize, &StackAllocator::instance()); + return inst; +} + +const Tensor* Create1dTensor(int size) { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + constexpr size_t tensor_shape_size = 1; + const int32_t tensor_shape[tensor_shape_size] = {size}; + const Offset tensor_offset = CreateTensor( + *builder, builder->CreateVector(tensor_shape, tensor_shape_size), + TensorType_INT32, 0, builder->CreateString("test_tensor"), 0, false); + builder->Finish(tensor_offset); + void* tensor_pointer = builder->GetBufferPointer(); + const Tensor* tensor = flatbuffers::GetRoot(tensor_pointer); + return tensor; +} + +const flatbuffers::Vector>* CreateBuffers() { + using flatbuffers::Offset; + flatbuffers::FlatBufferBuilder* builder = BuilderInstance(); + constexpr size_t buffers_size = 1; + const Offset buffers[buffers_size] = { + CreateBuffer(*builder), + }; + const flatbuffers::Offset>> + buffers_offset = builder->CreateVector(buffers, buffers_size); + builder->Finish(buffers_offset); + void* buffers_pointer = builder->GetBufferPointer(); + const flatbuffers::Vector>* result = + flatbuffers::GetRoot>>( + buffers_pointer); + return result; +} + +} // namespace +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestAllocateTensor) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + const tflite::Tensor* tensor = tflite::Create1dTensor(100); + const flatbuffers::Vector>* buffers = + tflite::CreateBuffers(); + + TfLiteTensor allocated_tensor; + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter, + &allocated_tensor)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type); + TF_LITE_MICRO_EXPECT_EQ(1, allocated_tensor.dims->size); + TF_LITE_MICRO_EXPECT_EQ(100, allocated_tensor.dims->data[0]); + TF_LITE_MICRO_EXPECT_EQ(400, allocated_tensor.bytes); + TF_LITE_MICRO_EXPECT_NE(nullptr, allocated_tensor.data.i32); +} + +TF_LITE_MICRO_TEST(TestTooLarge) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + const tflite::Tensor* tensor = tflite::Create1dTensor(10000); + const flatbuffers::Vector>* buffers = + tflite::CreateBuffers(); + + TfLiteTensor allocated_tensor; + TF_LITE_MICRO_EXPECT_NE( + kTfLiteOk, + allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter, + &allocated_tensor)); +} + +TF_LITE_MICRO_TEST(TestJustFits) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleTensorAllocator allocator(arena, arena_size); + + uint8_t* result = allocator.AllocateMemory(arena_size); + TF_LITE_MICRO_EXPECT_NE(nullptr, result); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/contrib/lite/experimental/micro/testing/BUILD b/tensorflow/contrib/lite/experimental/micro/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0d23be5712ad1bc6d81cc467cce8c9927caece3d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/BUILD @@ -0,0 +1,17 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["test_linux_binary.sh"]) + +cc_library( + name = "micro_test", + hdrs = [ + "micro_test.h", + ], + deps = [ + "//tensorflow/contrib/lite/experimental/micro:micro_framework", + ], +) diff --git a/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill new file mode 100644 index 0000000000000000000000000000000000000000..7d6d81af0f482afb7a9f0624b5262a5277112976 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill @@ -0,0 +1,21 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This docker configuration file lets you emulate a Blue Pill board +# on an x86 desktop or laptop, which can be useful for debugging and +# automated testing. +FROM antmicro/renode:latest + +LABEL maintainer="Pete Warden " \ No newline at end of file diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc new file mode 100644 index 0000000000000000000000000000000000000000..9333dc42bfbfbc0c6185a88db096b2cb2102d5be --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc @@ -0,0 +1,36 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +using sysbus + +mach create +machine LoadPlatformDescription @platforms/cpus/stm32f103.repl + +# These lines are needed to show the results of DebugLog calls in the output. +machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu" +showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer + +logFile @/tmp/renode_bluepill_log.txt + +macro reset +""" + sysbus LoadELF $bin +""" + +runMacro $reset + +emulation RunFor @1 + +quit \ No newline at end of file diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl new file mode 100644 index 0000000000000000000000000000000000000000..916e3eeac394f9a815d7c1785d253fd54ca7aa0e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl @@ -0,0 +1,67 @@ +"""Rules for simple testing without dependencies by parsing output logs.""" + +def tflite_micro_cc_test( + name, + expected_in_logs = "~~~ALL TESTS PASSED~~~", + srcs = [], + includes = [], + defines = [], + copts = [], + nocopts = "", + linkopts = [], + deps = [], + tags = [], + visibility = None): + """Tests a C/C++ binary without testing framework dependencies`. + + Runs a C++ binary, and tests that the output logs contain the + expected value. This is a deliberately spartan way of testing, to match + what's available when testing microcontroller binaries. + + Args: + name: a unique name for this rule. + expected_in_logs: A regular expression that is required to be + present in the binary's logs for the test to pass. + srcs: sources to compile (C, C++, ld scripts). + includes: include paths to add to this rule and its dependents. + defines: list of `VAR` or `VAR=VAL` to pass to CPP for this rule and + its dependents. + copts: gcc compilation flags for this rule only. + nocopts: list of gcc compilation flags to remove for this rule + only. No regexp like for `cc_library`. + linkopts: `gcc` flags to add to the linking phase. For "pure" ld flags, + prefix them with the `-Wl,` prefix here. + deps: dependencies. only `tflite_bare_metal_cc_library()` dependencies + allowed. + visibility: visibility. + """ + native.cc_binary( + name = name + "_binary", + srcs = srcs, + includes = includes, + defines = defines, + copts = copts, + nocopts = nocopts, + linkopts = linkopts, + deps = deps, + tags = tags, + visibility = visibility, + ) + native.sh_test( + name = name, + size = "medium", + srcs = [ + "//tensorflow/contrib/lite/experimental/micro/testing:test_linux_binary.sh", + ], + args = [ + native.package_name() + "/" + name + "_binary", + "'" + expected_in_logs + "'", + ], + data = [ + name + "_binary", + # Internal test dependency placeholder + ], + deps = [ + ], + tags = tags, + ) diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h new file mode 100644 index 0000000000000000000000000000000000000000..104509c9dc6123e84c45f26d03465f608f100310 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An ultra-lightweight testing framework designed for use with microcontroller +// applications. Its only dependency is on TensorFlow Lite's ErrorReporter +// interface, where log messages are output. This is designed to be usable even +// when no standard C or C++ libraries are available, and without any dynamic +// memory allocation or reliance on global constructors. +// +// To build a test, you use syntax similar to gunit, but with some extra +// decoration to create a hidden 'main' function containing each of the tests to +// be run. Your code should look something like: +// ---------------------------------------------------------------------------- +// #include "path/to/this/header" +// +// TF_LITE_MICRO_TESTS_BEGIN +// +// TF_LITE_MICRO_TEST(SomeTest) { +// TF_LITE_LOG_EXPECT_EQ(true, true); +// } +// +// TF_LITE_MICRO_TESTS_END +// ---------------------------------------------------------------------------- +// If you compile this for your platform, you'll get a normal binary that you +// should be able to run. Executing it will output logging information like this +// to stderr (or whatever equivalent is available and written to by +// ErrorReporter): +// ---------------------------------------------------------------------------- +// Testing SomeTest +// 1/1 tests passed +// ~~~ALL TESTS PASSED~~~ +// ---------------------------------------------------------------------------- +// This is designed to be human-readable, so you can just run tests manually, +// but the string "~~~ALL TESTS PASSED~~~" should only appear if all of the +// tests do pass. This makes it possible to integrate with automated test +// systems by scanning the output logs and looking for that magic value. +// +// This framework is intended to be a rudimentary alternative to no testing at +// all on systems that struggle to run more conventional approaches, so use with +// caution! + +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ + +#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h" + +namespace micro_test { +extern int tests_passed; +extern int tests_failed; +extern bool is_test_complete; +extern bool did_test_fail; +extern tflite::ErrorReporter* reporter; +} // namespace micro_test + +#define TF_LITE_MICRO_TESTS_BEGIN \ + namespace micro_test { \ + int tests_passed; \ + int tests_failed; \ + bool is_test_complete; \ + bool did_test_fail; \ + tflite::ErrorReporter* reporter; \ + } \ + \ + int main(int argc, char** argv) { \ + micro_test::tests_passed = 0; \ + micro_test::tests_failed = 0; \ + tflite::MicroErrorReporter error_reporter; \ + micro_test::reporter = &error_reporter; + +#define TF_LITE_MICRO_TESTS_END \ + micro_test::reporter->Report( \ + "%d/%d tests passed", micro_test::tests_passed, \ + (micro_test::tests_failed + micro_test::tests_passed)); \ + if (micro_test::tests_failed == 0) { \ + micro_test::reporter->Report("~~~ALL TESTS PASSED~~~\n"); \ + } else { \ + micro_test::reporter->Report("~~~SOME TESTS FAILED~~~\n"); \ + } \ + } + +// TODO(petewarden): I'm going to hell for what I'm doing to this poor for loop. +#define TF_LITE_MICRO_TEST(name) \ + micro_test::reporter->Report("Testing %s", #name); \ + for (micro_test::is_test_complete = false, \ + micro_test::did_test_fail = false; \ + !micro_test::is_test_complete; micro_test::is_test_complete = true, \ + micro_test::tests_passed += (micro_test::did_test_fail) ? 0 : 1, \ + micro_test::tests_failed += (micro_test::did_test_fail) ? 1 : 0) + +#define TF_LITE_MICRO_EXPECT(x) \ + do { \ + if (!(x)) { \ + micro_test::reporter->Report(#x " failed at %s:%d", __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_EQ(x, y) \ + do { \ + if ((x) != (y)) { \ + micro_test::reporter->Report(#x " == " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_NE(x, y) \ + do { \ + if ((x) == (y)) { \ + micro_test::reporter->Report(#x " != " #y " failed at %s:%d", __FILE__, \ + __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ + do { \ + auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \ + if (delta > epsilon) { \ + micro_test::reporter->Report(#x " near " #y " failed at %s:%d", \ + __FILE__, __LINE__); \ + micro_test::did_test_fail = true; \ + } \ + } while (false) + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_ diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh new file mode 100755 index 0000000000000000000000000000000000000000..07742a8262f8cdf5981be2a057631d975cd04d33 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh @@ -0,0 +1,54 @@ +#!/bin/bash -e +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Tests a 'bluepill' STM32F103 ELF by parsing the log output of Renode emulation. +# +# First argument is the ELF location. +# Second argument is a regular expression that's required to be in the output logs +# for the test to pass. + +declare -r ROOT_DIR=`pwd` +declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/ +declare -r MICRO_LOG_PATH=${TEST_TMPDIR} +declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt +mkdir -p ${MICRO_LOG_PATH} + +docker build -t renode_bluepill \ + -f ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill \ + ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/ + +docker run \ + --log-driver=none -a stdout -a stderr \ + -v ${ROOT_DIR}:/workspace \ + -v /tmp:/tmp \ + -it renode_bluepill \ + /bin/bash -c "renode -P 5000 --disable-xwt -e ' +\$bin?=@/workspace/$1 +s @/workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc +' 2>&1 >${MICRO_LOG_FILENAME}" + +echo "LOGS:" +cat ${MICRO_LOG_FILENAME} + +if grep -q "$2" ${MICRO_LOG_FILENAME} +then + echo "$1: PASS" + exit 0 +else + echo "$1: FAIL - '$2' not found in logs." + exit 1 +fi + diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh new file mode 100755 index 0000000000000000000000000000000000000000..24131a6d2df6c0187696b7c21efba2323ef1a305 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh @@ -0,0 +1,39 @@ +#!/bin/bash -e +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Tests a Linux binary by parsing the log output. +# +# First argument is the binary location. +# Second argument is a regular expression that's required to be in the output logs +# for the test to pass. + +declare -r ROOT_DIR=`pwd` +declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/ +declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1 +declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt +mkdir -p ${MICRO_LOG_PATH} + +$1 2>&1 | tee ${MICRO_LOG_FILENAME} + +if grep -q "$2" ${MICRO_LOG_FILENAME} +then + echo "$1: PASS" + exit 0 +else + echo "$1: FAIL - '$2' not found in logs." + exit 1 +fi + diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..880bb4763cbbaf58db286ff142a822fbab60dfd8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile @@ -0,0 +1,166 @@ +MAKEFILE_DIR := tensorflow/contrib/lite/experimental/micro/tools/make + +# Try to figure out the host system +HOST_OS := +ifeq ($(OS),Windows_NT) + HOST_OS = windows +else + UNAME_S := $(shell uname -s) + ifeq ($(UNAME_S),Linux) + HOST_OS := linux + endif + ifeq ($(UNAME_S),Darwin) + HOST_OS := osx + endif +endif + +HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Override these on the make command line to target a specific architecture. For example: +# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l +TARGET := $(HOST_OS) +TARGET_ARCH := $(HOST_ARCH) + +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(OBJDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + +TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh + +MICROLITE_LIBS := -lm + +# There are no rules for compiling objects for the host system (since we don't +# generate things like the protobuf compiler that require that), so all of +# these settings are for the target compiler. +CXXFLAGS := -O3 -DNDEBUG +CXXFLAGS += --std=c++11 -g -DTF_LITE_STATIC_MEMORY +CCFLAGS := -DNDEBUG -g -DTF_LITE_STATIC_MEMORY +LDOPTS := -L/usr/local/lib +ARFLAGS := -r +TARGET_TOOLCHAIN_PREFIX := +CC_PREFIX := + +# This library is the main target for this makefile. It will contain a minimal +# runtime that can be linked in to other programs. +MICROLITE_LIB_NAME := libtensorflow-microlite.a + +# Test binary for the microcontroller speech model. +MICRO_SPEECH_TEST_SRCS := \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \ +tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc + +MICROLITE_TEST_SRCS := \ +$(wildcard tensorflow/contrib/lite/experimental/micro/*test.cc) \ +$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*test.cc) + +MICROLITE_CC_BASE_SRCS := \ +$(wildcard tensorflow/contrib/lite/experimental/micro/*.cc) \ +$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*.cc) \ +tensorflow/contrib/lite/c/c_api_internal.c \ +tensorflow/contrib/lite/core/api/error_reporter.cc \ +tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc \ +tensorflow/contrib/lite/core/api/op_resolver.cc \ +tensorflow/contrib/lite/kernels/kernel_util.cc \ +tensorflow/contrib/lite/kernels/internal/quantization_util.cc +MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) + +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +ALL_SRCS := \ + $(MICRO_SPEECH_TEST_SRCS) \ + $(MICROLITE_CC_SRCS) \ + $(MICROLITE_TEST_SRCS) + +# Where compiled objects are stored. +GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ + +MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME) + +MICRO_SPEECH_TEST_BINARY := $(BINDIR)micro_speech_test + +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar + +MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_TEST_SRCS)))) + +MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS)))) + +MICROLITE_TEST_TARGETS := $(addprefix $(BINDIR), \ +$(patsubst %_test.cc,%.test_target,$(MICROLITE_TEST_SRCS))) + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# For normal manually-created TensorFlow C source files. +$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ + +# The target that's compiled if there's no command-line arguments. +all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY) + +microlite: $(MICROLITE_LIB_PATH) + +# Hack for generating schema file bypassing flatbuffer parsing +tensorflow/contrib/lite/schema/schema_generated.h: + @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h + +# Gathers together all the objects we've compiled into a single '.a' archive. +$(MICROLITE_LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(MICROLITE_LIB_PATH) $(MICROLITE_LIB_OBJS) + +$(MICRO_SPEECH_TEST_BINARY): $(MICRO_SPEECH_TEST_OBJS) $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(MICRO_SPEECH_TEST_BINARY) $(MICRO_SPEECH_TEST_OBJS) \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) + +micro_speech_test: $(MICRO_SPEECH_TEST_BINARY) +micro_speech_test_bin: $(MICRO_SPEECH_TEST_BINARY).bin + +test_micro_speech: $(MICRO_SPEECH_TEST_BINARY) + $(TEST_SCRIPT) $(MICRO_SPEECH_TEST_BINARY) '~~~ALL TESTS PASSED~~~' + +$(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $@ $< \ + $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS) + +$(BINDIR)%.test_target: $(BINDIR)%_test + $(TEST_SCRIPT) $< '~~~ALL TESTS PASSED~~~' + +$(info $(MICROLITE_TEST_TARGETS)) + +test: test_micro_speech $(MICROLITE_TEST_TARGETS) + +# Gets rid of all generated files. +clean: + rm -rf $(MAKEFILE_DIR)/gen + +$(DEPDIR)/%.d: ; +.PRECIOUS: $(DEPDIR)/%.d +.PRECIOUS: $(BINDIR)%_test + +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS))) diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh new file mode 100755 index 0000000000000000000000000000000000000000..4c2ff8545dbdcc426bf62aaeb07ca22d8b17cc69 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../../../../../.." + +DOWNLOADS_DIR=tensorflow/contrib/lite/experimental/micro/tools/make/downloads +BZL_FILE_PATH=tensorflow/workspace.bzl + +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + +GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz" +CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip" +STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/50e0da307a2821bb54af1f57b969e6b76cb89d32.zip" + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + if [[ "${url}" == *gz ]]; then + curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz + elif [[ "${url}" == *zip ]]; then + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. + cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi + rm -rf ${tempdir2} ${tempdir} + fi + + # Delete any potential BUILD files, which would interfere with Bazel builds. + find "${dir}" -type f -name '*BUILD' -delete +} + +download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" +download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${CMSIS_URL}" "${DOWNLOADS_DIR}/cmsis" +download_and_extract "${STM32_BARE_LIB_URL}" "${DOWNLOADS_DIR}/stm32_bare_lib" + +echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..022a8422dc89c048797d0f9ba224f67060d210d7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc @@ -0,0 +1,65 @@ +# Settings for Blue Pill platforms. +ifeq ($(TARGET), bluepill) + TARGET_ARCH := cortex-m3 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + PLATFORM_FLAGS = \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTF_LITE_STATIC_MEMORY \ + -DTF_LITE_MCU_DEBUG_LOG \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-unwind-tables \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m3 \ + -mthumb \ + -std=gnu++11 \ + -Wvla \ + -Wall \ + -Wextra \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -Wno-write-strings \ + -Wno-sign-compare \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -fpermissive \ + -nostdlib \ + -g \ + -Os + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + LDFLAGS += \ + -T $(MAKEFILE_DIR)/downloads/stm32_bare_lib/stm32_linker_layout.lds \ + -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \ + -Wl,--gc-sections + BUILD_TYPE := micro + MICROLITE_LIBS := \ + -lm + INCLUDES += \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include + MICROLITE_CC_SRCS += \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) + TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh + # These are tests that don't currently work on the blue pill. + EXCLUDED_TESTS := \ + tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \ + tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc + MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) + +# These are microcontroller-specific rules for converting the ELF output +# of the linker into a binary image that can be loaded directly. +OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy + +$(BINDIR)/%.bin: $(BINDIR)/% + @mkdir -p $(dir $@) + $(OBJCOPY) $< $@ -O binary + +endif \ No newline at end of file diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc index e6d5a776b328812369aad7270bfafe4b74a88331..b35c6e06553c44c10979cd3edb68fa76638e6602 100644 --- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include #include -#include "flatbuffers/minireflect.h" // flatbuffers +#include "flatbuffers/minireflect.h" // TF:flatbuffers #include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" namespace tflite { diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc index 52b17faf82e9cd5cb402304a53ddf02c09a9d9d5..555a9cc4b09f30e2344ff30c409d2d2c37e6ea41 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc @@ -117,6 +117,8 @@ Offset>> InterpreterWriter::ExportOperators( Offset>> InterpreterWriter::ExportTensors( FlatBufferBuilder* fbb) { + // Initialized to -1. + // A value of -1 means this tensor will not be exported. tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); std::vector> tensors; @@ -135,15 +137,17 @@ Offset>> InterpreterWriter::ExportTensors( int curr_output_index = 0; for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); tensor_index++) { - if (!tensor_is_temporary[tensor_index]) { + // Temporary tensors and unused tensors will not be written. + if (!tensor_is_temporary[tensor_index] && + unused_tensors_.find(tensor_index) == unused_tensors_.end()) { tensor_to_written_tensor_[tensor_index] = curr_output_index++; } } for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); ++tensor_index) { - // Skip temporaries. - if (tensor_is_temporary[tensor_index]) continue; + // Tensor not exported. + if (tensor_to_written_tensor_[tensor_index] == -1) continue; if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { // We only need to convert non temporaries @@ -215,7 +219,9 @@ std::vector InterpreterWriter::RemapTensorIndicesToWritten( std::vector output; output.reserve(input.size()); for (int x : input) { - output.push_back(tensor_to_written_tensor_[x]); + if (tensor_to_written_tensor_[x] != -1) { + output.push_back(tensor_to_written_tensor_[x]); + } } return output; } diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h index a98108b4960be5cf6243f21f7ff5a6925113e427..a5f14697cfd223a637770e66bdc02278383144b2 100644 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h @@ -62,6 +62,10 @@ class InterpreterWriter { // caller to change the custom data. TfLiteStatus RegisterCustomWriter(const std::string& custom_name, CustomWriter custom_writer); + // Tensors that are unused and shouldn't be written. + void SetUnusedTensors(const std::set& unused_tensors) { + unused_tensors_ = unused_tensors; + } private: template @@ -111,8 +115,9 @@ class InterpreterWriter { int builtin; std::string custom; }; + std::set unused_tensors_; // For every tensor index in the interpreter, the index in the written. - // This is different due to temporary tensors not being written. + // This is different due to temporary and unused tensors not being written. std::vector tensor_to_written_tensor_; // List of used opcodes std::vector opcodes_; diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml index 1dffe30790aac03b32f11b6a9035d187e79edd18..f6ec387ad2feb488023941ea599951f6ae3acc19 100644 --- a/tensorflow/contrib/lite/g3doc/_book.yaml +++ b/tensorflow/contrib/lite/g3doc/_book.yaml @@ -5,7 +5,7 @@ upper_tabs: # Dropdown menu - name: Ecosystem path: /ecosystem - is_default: True + is_default: true menu: - include: /ecosystem/_menu_toc.yaml lower_tabs: @@ -14,46 +14,59 @@ upper_tabs: - name: Guide contents: - title: Overview - path: /mobile/overview - - title: Developer Guide - path: /mobile/devguide - - title: Android Demo App - path: /mobile/demo_android - - title: iOS Demo App - path: /mobile/demo_ios + path: /lite/overview + - title: Developer guide + path: /lite/devguide + - title: Android demo app + path: /lite/demo_android + - title: iOS demo app + path: /lite/demo_ios - title: Performance - path: /mobile/performance - - break: True + path: /lite/performance + - break: true - title: TensorFlow Lite APIs - path: /mobile/apis + path: /lite/apis - title: Custom operators - path: /mobile/custom_operators - - title: TensorFlow Lite Ops Versioning - path: /mobile/ops_versioning - - title: TensorFlow Lite Compatibility Guide - path: /mobile/tf_ops_compatibility - - title: List of Hosted Models - path: /mobile/models + path: /lite/custom_operators + - title: TensorFlow Lite ops versioning + path: /lite/ops_versioning + - title: TensorFlow Lite compatibility guide + path: /lite/tf_ops_compatibility + - title: List of hosted models + path: /lite/models - title: TensorFlow Lite for iOS - path: /mobile/ios + path: /lite/ios - title: TensorFlow Lite for Raspberry Pi - path: /mobile/rpi + path: /lite/rpi + - heading: TFLite Converter + - title: Overview + path: /lite/tflite_convert/ + - title: Python API + path: /lite/tflite_convert/python_api + - title: Command Line Examples + path: /lite/tflite_convert/cmdline_examples + - title: Command Line Reference + path: /lite/tflite_convert/cmdline_reference - - heading: TF Mobile + - title: TF Mobile + style: accordion status: deprecated - - title: Overview - path: /mobile/tfmobile/ - - title: Building TensorFlow on Android - path: /mobile/tfmobile/android_build - - title: Building TensorFlow on IOS - path: /mobile/tfmobile/ios_build - - title: Integrating TensorFlow libraries - path: /mobile/tfmobile/linking_libs - - title: Preparing models for mobile deployment - path: /mobile/tfmobile/prepare_models - - title: Optimizing for mobile - path: /mobile/tfmobile/optimizing + section: + - title: Overview + path: /lite/tfmobile/ + - title: Building TensorFlow on Android + path: /lite/tfmobile/android_build + - title: Building TensorFlow on IOS + path: /lite/tfmobile/ios_build + - title: Integrating TensorFlow libraries + path: /lite/tfmobile/linking_libs + - title: Preparing models for mobile deployment + path: /lite/tfmobile/prepare_models + - title: Optimizing for mobile + path: /lite/tfmobile/optimizing - name: API + skip_translation: true contents: - - include: /mobile/api_docs/python/_toc.yaml + - title: API + path: /api_docs/python/tf/contrib/lite diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml index 9119e49117ffbda268f36324072d30ffd83c9e6c..bc66cc5dc1606537b7e186f3c825ab8335aa9e91 100644 --- a/tensorflow/contrib/lite/g3doc/_index.yaml +++ b/tensorflow/contrib/lite/g3doc/_index.yaml @@ -1,59 +1,209 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml +project_path: /lite/_project.yaml +book_path: /lite/_book.yaml description: landing_page: + custom_css_path: /site-assets/css/style.css rows: - - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices. + - heading: TensorFlow Lite is for mobile and embedded devices. + description: > +

+ TensorFlow Lite is the official solution for running machine learning + models on mobile and embedded devices. It enables on‑device machine + learning inference with low latency and a small binary size on Android, + iOS, and other operating systems. +

+ + + - classname: tfo-landing-row-heading tfo-landing-row-heading-list + heading: Many benefits + description: > + On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these: items: - - description: > - TensorFlow Lite is TensorFlow’s lightweight solution for mobile and - embedded devices. It enables on-device machine learning inference with - low latency and a small binary size. TensorFlow Lite also supports - hardware acceleration with the - Android Neural Networks API. - list: - - heading: Key point 1 + - list: + - heading: Performance + description: > + TF Lite is fast with no noticeable accuracy loss—see the metrics. + icon: + icon_name: lens + foreground: theme + - heading: Portability description: > - [high-level overview] + Android, + iOS, and more specialized IoT devices. icon: - icon_name: chevron_right + icon_name: lens foreground: theme - background: grey - - heading: Key point 2 + - list: + - heading: Low latency description: > - [high-level overview] + Optimized float- and fixed-point CPU kernels, op‑fusing, and more. icon: - icon_name: chevron_right + icon_name: lens foreground: theme - background: grey - - heading: Key point 3 + - heading: Acceleration description: > - [high-level overview] + Integration with GPU and internal/external accelerators. icon: - icon_name: chevron_right + icon_name: lens foreground: theme - background: grey - - code_block: | -
-        $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
-               --input_format=TENSORFLOW_GRAPHDEF \
-               --output_format=TFLITE \
-               --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
-               --inference_type=FLOAT \
-               --input_type=FLOAT \
-               --input_arrays=input \
-               --output_arrays=MobilenetV1/Predictions/Reshape_1 \
-               --input_shapes=1,224,224,3
-        
+ - list: + - heading: Small model size + description: > + Controlled dependencies, quantization, + and op registration. + icon: + icon_name: lens + foreground: theme + - heading: Tooling + description: > + Conversion, compression, benchmarking, power-consumption, and more. + icon: + icon_name: lens + foreground: theme + + - classname: devsite-landing-row-logos tfo-landing-row-heading + heading: Companies using TensorFlow Lite + items: + - custom_image: + path: ./images/landing-page/photos_logo.png + path: https://www.photos.google.com + - custom_image: + path: ./images/landing-page/gboard_logo.png + path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US + - custom_image: + path: ./images/landing-page/gmail_logo.png + path: https://www.google.com/gmail/ + - custom_image: + path: ./images/landing-page/assistant_logo.png + path: https://assistant.google.com/ + + - classname: devsite-landing-row-logos + items: + - custom_image: + path: ./images/landing-page/vsco_logo.png + path: https://vsco.co + - custom_image: + path: ./images/landing-page/shazam_logo.png + path: https://www.shazam.com/ + - custom_image: + path: ./images/landing-page/nest_logo.png + path: https://nest.com/ + - custom_image: + path: ./images/landing-page/loseit_logo.png + path: https://www.loseit.com/ + + - classname: devsite-landing-row-no-image-background devsite-landing-row-67 + background: grey + items: + - description: > + “TensorFlow Lite helped us introduce machine learning and AI into our + app in an easy and streamlined way. We could reduce the size of our + models while keeping the accuracy high. This helped us create an amazing + fishing experience for our users by allowing them to identify any fish + species with just a photo.” + image_path: ./images/landing-page/fishbrain_logo_big.png + + - heading: How it works + items: + - heading: Build + icon: + icon_name: build + description: > + Build a new model or retrain an existing one, such as using transfer learning. + buttons: + - label: Read the developer guide + path: /lite/devguide + classname: button button-primary tfo-button-primary + - heading: Convert + icon: + icon_name: autorenew + description: > + Convert a TensorFlow model into a compressed flat buffer with the + TensorFlow Lite Optimizing Converter (TOCO). + buttons: + - label: Read the TOCO guide + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md + classname: button button-primary tfo-button-primary + - heading: Deploy + icon: + icon_name: bolt + description: > + Take the compressed .tflite file and load it into a mobile + or embedded device.
+ See the tutorials below to build an app. + + - heading: Build your first TensorFlow Lite app + background: grey + items: + - classname: tfo-landing-row-item-inset-white + heading: Get started + description: > + + - classname: tfo-landing-row-item-inset-white + heading: Share your TensorFlow Lite story + description: > + We love to hear what you're working on—it may even get highlighted on + our social media! Tell us. + + - classname: devsite-landing-row-no-image-background devsite-landing-row-67 + items: + - description: > +

+ “The release of TensorFlow Lite has allowed us to deploy an engaging + real-time experience to our users that eliminates the requirement + for a data connection. TensorFlow Lite’s ability to compress and + optimize the TensorFlow graph for mobile deployment has been + transformative in expanding the capabilities of Snap It. +

+

+ Through TensorFlow Lite, our users can now enjoy a state of the + art, computer-vision-based food logging experience without worrying + about signal strength. We look forward to future collaborations + with the TensorFlow Lite team.” +

+ image_path: ./images/landing-page/loseit_logo_big.png - classname: devsite-landing-row-cards + background: grey + heading: Updates items: + - heading: Introducing the Model Optimization Toolkit + image_path: /ecosystem/images/tf-logo-card-16x9.png + path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3 + buttons: + - label: Read on TensorFlow blog + path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3 + - heading: East Africa Cassava App + image_path: ./images/landing-page/detect_crop_disease_in_africa.png + path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5 + buttons: + - label: Read more + path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5 - heading: Using TensorFlow Lite on Android image_path: /ecosystem/images/tf-logo-card-16x9.png path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d buttons: - label: Read on TensorFlow blog path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d + + - classname: devsite-landing-row-cards + background: grey + items: - heading: TensorFlow Lite at the Dev Summit youtube_id: FAMfy7izB6A buttons: @@ -65,3 +215,4 @@ landing_page: buttons: - label: View on GitHub path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite + - classname: devsite-landing-row-item-hidden diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml index b39666516baab42d289e4d40077c2877ed65d396..3ce698639647d9e105b6748512314aeca148b0a0 100644 --- a/tensorflow/contrib/lite/g3doc/_project.yaml +++ b/tensorflow/contrib/lite/g3doc/_project.yaml @@ -1,10 +1,10 @@ name: TensorFlow Lite -breadcrumb_name: Mobile -home_url: /mobile/ +breadcrumb_name: TensorFlow Lite +home_url: /lite/ parent_project_metadata_path: /_project.yaml description: > TensorFlow Lite is a lightweight solution for mobile and embedded devices. -use_site_branding: True -hide_from_products_list: True +use_site_branding: true +hide_from_products_list: true content_license: cc3-apache2 buganizer_id: 316308 diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml deleted file mode 100644 index 1e1c44c6929571144d8cf0b54463c48e37466022..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml +++ /dev/null @@ -1,6 +0,0 @@ -# Automatically generated file; please do not edit -toc: - - title: TensorFlow Lite - section: - - title: Overview - path: /mobile/api_docs/python/ diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md index 90e7915c52cecc7fff108cbe829aaa97b0fc4ce3..0eed5160009c07727f0c2985ebe963efc7bb9d8e 100644 --- a/tensorflow/contrib/lite/g3doc/devguide.md +++ b/tensorflow/contrib/lite/g3doc/devguide.md @@ -1,5 +1,4 @@ - -# Developer Guide +# TF Lite Developer Guide Using a TensorFlow Lite model in your mobile app requires multiple considerations: you must choose a pre-trained or custom model, convert the model @@ -55,7 +54,7 @@ both floating point and quantized inference. ### Train a custom model A developer may choose to train a custom model using Tensorflow (see the -[TensorFlow tutorials](../../tutorials/) for examples of building and training +[TensorFlow tutorials](../tutorials/) for examples of building and training models). If you have already written a model, the first step is to export this to a `tf.GraphDef` file. This is required because some formats do not store the model structure outside the code, and we must communicate with other parts of the @@ -205,7 +204,7 @@ The open source Android demo app uses the JNI interface and is available [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). You can also download a [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). -See the Android demo guide for details. +See the Android demo guide for details. The Android mobile guide has instructions for installing TensorFlow on Android and setting up `bazel` and Android Studio. @@ -214,7 +213,7 @@ installing TensorFlow on Android and setting up `bazel` and Android Studio. To integrate a TensorFlow model in an iOS app, see the [TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) -guide and iOS demo guide. +guide and iOS demo guide. #### Core ML support diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..ced0872ab2e69768cc3d1b759032a8ed7ece2149 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png new file mode 100644 index 0000000000000000000000000000000000000000..45b3b4f6fe9ce69508d488f761e29f90c4304040 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..bc1bf6e1e719adb41b08c967d5adc2b7839d9453 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png new file mode 100644 index 0000000000000000000000000000000000000000..d76fca86a92d4b77d529e3572acaa6198a986b86 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..f1a93ab76307168eff28fdb08d4780b2e2cf5ff8 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..21aa2c84ea56a1c627eace2be610e3df69468450 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..b6b3d14df994748c37953fc67ebc6bc6d62e2607 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png new file mode 100644 index 0000000000000000000000000000000000000000..b3e46d4bd8c1bc3d2c5165c182c04e18db659c68 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..35bfd97373279a3a5a1f8f622d3358ecbfce10f2 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..4333426dfe008e399786c19c4a312693230860a4 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..6ec412c75c51d0e6f7cfffa011d250907731c12e Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..f408f9024b3036df362a9c792b5af1dc99ae939b Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png new file mode 100644 index 0000000000000000000000000000000000000000..44d0ccd3128dea1c947e57ccbc4e18b2d34cef88 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png differ diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png new file mode 100644 index 0000000000000000000000000000000000000000..94a6310612828db2370d19a094795341478e90f8 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png differ diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index a83d2c8fec7c9638bbdebd851fec74a46b624553..3b9fcca8117dc1859d075ae5f048cfc9f0d988a3 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -1,5 +1,10 @@ -# TensorFlow Lite for iOS +# Build TensorFlow Lite for iOS + +This document describes how to build TensorFlow Lite iOS library. If you just +want to use it, the easiest way is using the TensorFlow Lite CocoaPod releases. +See [TensorFlow Lite iOS Demo](demo_ios.md) for examples. + ## Building diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index a4267eee4cca6db20b95a6a696f1f3373aabdd54..279764ce964e523c769addda2b477690694dc048 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,6 +1,23 @@ # List of Hosted Models +# AutoML mobile image classification models (Float Models) + +Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^ +------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------: +MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms +MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms +MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms +MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms +MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms +MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms +MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms +MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms +MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms + +^ Performance numbers are generated on Pixel-1 using single thread large BIG core. + + ## Image classification (Float Models) Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md index 8cf43496dfef351cb094db9c9355b280d112e2fa..9d035a69211d7ced913e6d16061c6ad8ca912e64 100644 --- a/tensorflow/contrib/lite/g3doc/overview.md +++ b/tensorflow/contrib/lite/g3doc/overview.md @@ -25,7 +25,7 @@ models. TensorFlow Lite defines a new model file format, based on [FlatBuffers](https://google.github.io/flatbuffers/). FlatBuffers is an -open-sourced, efficient cross platform serialization library. It is similar to +efficient open-source cross-platform serialization library. It is similar to [protocol buffers](https://developers.google.com/protocol-buffers/?hl=en), but the primary difference is that FlatBuffers does not need a parsing/unpacking step to a secondary representation before you can access data, often coupled diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md index 28cb6aba6ec61d12d86e078e47665833df8afec7..ed114527166da79dba2d92c3ffad78e9885f9e94 100644 --- a/tensorflow/contrib/lite/g3doc/performance.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -1,174 +1,45 @@ -# Performance - -This document lists TensorFlow Lite performance benchmarks when running well -known models on some Android and iOS devices. - -These performance benchmark numbers were generated with the -[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) -and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). - -# Android performance benchmarks - -For Android benchmarks, the CPU affinity is set to use big cores on the device to -reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). - -It assumes that models were download and unzipped to the -`/data/local/tmp/tflite_models` directory. The benchmark binary is built -using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) -and assumed in the `/data/local/tmp` directory. - -To run the benchmark: - -``` -adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ - --num_threads=1 \ - --graph=/data/local/tmp/tflite_models/${GRAPH} \ - --warmup_runs=1 \ - --num_runs=50 \ - --use_nnapi=false -``` - -Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity -chosen according to the following table: - -Device | CPU_MASK | --------| ---------- -Pixel 2 | f0 | -Pixel xl | 0c | - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Model NameDevice Mean inference time (std dev)
- Mobilenet_1.0_224(float) - Pixel 2 166.5 ms (2.6 ms)
Pixel xl 122.9 ms (1.8 ms)
- Mobilenet_1.0_224 (quant) - Pixel 2 69.5 ms (0.9 ms)
Pixel xl 78.9 ms (2.2 ms)
- NASNet mobile - Pixel 2 273.8 ms (3.5 ms)
Pixel xl 210.8 ms (4.2 ms)
- SqueezeNet - Pixel 2 234.0 ms (2.1 ms)
Pixel xl 158.0 ms (2.1 ms)
- Inception_ResNet_V2 - Pixel 2 2846.0 ms (15.0 ms)
Pixel xl 1973.0 ms (15.0 ms)
- Inception_V4 - Pixel 2 3180.0 ms (11.7 ms)
Pixel xl 2262.0 ms (21.0 ms)
- -# iOS benchmarks - -To run iOS benchmarks, the [benchmark -app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) -was modified to include the appropriate model and `benchmark_params.json` was -modified to set `num_threads` to 1. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Model NameDevice Mean inference time (std dev)
- Mobilenet_1.0_224(float) - iPhone 8 32.2 ms (0.8 ms)
- Mobilenet_1.0_224 (quant) - iPhone 8 24.4 ms (0.8 ms)
- NASNet mobile - iPhone 8 60.3 ms (0.6 ms)
- SqueezeNet - iPhone 8 44.3 (0.7 ms)
- Inception_ResNet_V2 - iPhone 8562.4 ms (18.2 ms)
- Inception_V4 - iPhone 8 661.0 ms (29.2 ms)
+# Performance best practices + +Mobile and embedded devices have limited computational resources and it is important to keep your application resource efficient. We have compiled a list of best practices and strategies you can use to optimize your model and application when using Tensorflow Lite. + +## Choose the best model for the task +Depending on the task you will need to make a tradeoff between model complexity and size. If your task requires high accuracy then you may need a large and complex model. Some tasks may work with a less precise model, for these tasks it is better to use a smaller but less precise model. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. For example, graphs below show accuracy and latency tradeoff for some common image classification models. + +![accuracy vs model size](images/performance/model_size_vs_accuracy.png "Accuracy vs Model size") + + +![latency vs model size](images/performance/model_size_vs_latency.png "Latency vs Model size") + +One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices. + +You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for +[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and + [object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193). + + +## Profile your model +Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time. + +## Profile and optimize operators in the graph +If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator. + This scenario should be rare as Tensorflow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](custom_operators.md). + +## Quantize your model +If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. + +## Tweak the number of threads +Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads. Multi-threaded execution however comes at the cost of increased performance variability depending on what else is been executed concurrently. This is particularly the case for mobile apps. For example, isolated tests may show 2x speed up vs single-threaded but if another app is executing at the same time may result in worst performance than single-threaded. + +## Eliminate redundant copies +If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151). + +## Profile your application with platform specific tools +Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform. + +## Evaluate whether your model benefits from using hardware accelerators available on the device +Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android. +You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance. + +## Need more help +The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue. diff --git a/tensorflow/contrib/lite/g3doc/performance_benchmarks.md b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md new file mode 100644 index 0000000000000000000000000000000000000000..28cb6aba6ec61d12d86e078e47665833df8afec7 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md @@ -0,0 +1,174 @@ + +# Performance + +This document lists TensorFlow Lite performance benchmarks when running well +known models on some Android and iOS devices. + +These performance benchmark numbers were generated with the +[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) +and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +# Android performance benchmarks + +For Android benchmarks, the CPU affinity is set to use big cores on the device to +reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). + +It assumes that models were download and unzipped to the +`/data/local/tmp/tflite_models` directory. The benchmark binary is built +using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) +and assumed in the `/data/local/tmp` directory. + +To run the benchmark: + +``` +adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ + --num_threads=1 \ + --graph=/data/local/tmp/tflite_models/${GRAPH} \ + --warmup_runs=1 \ + --num_runs=50 \ + --use_nnapi=false +``` + +Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity +chosen according to the following table: + +Device | CPU_MASK | +-------| ---------- +Pixel 2 | f0 | +Pixel xl | 0c | + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + Pixel 2 166.5 ms (2.6 ms)
Pixel xl 122.9 ms (1.8 ms)
+ Mobilenet_1.0_224 (quant) + Pixel 2 69.5 ms (0.9 ms)
Pixel xl 78.9 ms (2.2 ms)
+ NASNet mobile + Pixel 2 273.8 ms (3.5 ms)
Pixel xl 210.8 ms (4.2 ms)
+ SqueezeNet + Pixel 2 234.0 ms (2.1 ms)
Pixel xl 158.0 ms (2.1 ms)
+ Inception_ResNet_V2 + Pixel 2 2846.0 ms (15.0 ms)
Pixel xl 1973.0 ms (15.0 ms)
+ Inception_V4 + Pixel 2 3180.0 ms (11.7 ms)
Pixel xl 2262.0 ms (21.0 ms)
+ +# iOS benchmarks + +To run iOS benchmarks, the [benchmark +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +was modified to include the appropriate model and `benchmark_params.json` was +modified to set `num_threads` to 1. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + iPhone 8 32.2 ms (0.8 ms)
+ Mobilenet_1.0_224 (quant) + iPhone 8 24.4 ms (0.8 ms)
+ NASNet mobile + iPhone 8 60.3 ms (0.6 ms)
+ SqueezeNet + iPhone 8 44.3 (0.7 ms)
+ Inception_ResNet_V2 + iPhone 8562.4 ms (18.2 ms)
+ Inception_V4 + iPhone 8 661.0 ms (29.2 ms)
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 8660d29855899c110df9dd1746d0e6f1075f21e5..b0dfb0fed1f7a072487a06c11bddf5545911ffdf 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -866,6 +866,17 @@ Outputs { } ``` +**ZEROS_LIKE** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: A tensor of the same shape and type as x but filled with zeros +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md similarity index 78% rename from tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md rename to tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md index 84680b968e87275b5f26c9a6dbab0ff41ebd505b..d88acfae80a8af9755d212973784194a99559097 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md @@ -1,57 +1,33 @@ -# TensorFlow Lite Optimizing Converter command-line examples - -This page provides examples on how to use TOCO via command line. It is -complemented by the following documents: - -* [README](../README.md) -* [Command-line glossary](cmdline_reference.md) -* [Python API examples](python_api.md) - -Table of contents: - -* [Command-line tools](#tools) - * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) -* [Basic examples](#basic) - * [Convert a TensorFlow GraphDef](#graphdef) - * [Convert a TensorFlow SavedModel](#savedmodel) - * [Convert a tf.keras model](#keras) -* [Quantization](#quantization) - * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) - * [Use "dummy-quantization" to try out quantized inference on a float - graph](#dummy-quant) -* [Specifying input and output arrays](#specifying-input-and-output-arrays) - * [Multiple input arrays](#multiple-input-arrays) - * [Multiple output arrays](#multiple-output-arrays) - * [Specifying subgraphs](#specifying-subgraphs) -* [Graph visualizations](#graph-visualizations) - * [Using --output_format=GRAPHVIZ_DOT](#using-output-format-graphviz-dot) - * [Using --dump_graphviz_dir](#using-dump-graphviz-dir) - * [Graph "video" logging](#graph-video-logging) - * [Legend for the graph visualizations](#graphviz-legend) +# TensorFlow Lite Converter command-line examples + +This page shows how to use the TensorFlow Lite Converter in the command line. + +[TOC] ## Command-line tools -There are two approaches to running TOCO via command line. +There are two approaches to running the converter in the command line. * `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool - `tflite_convert` will be installed as part of the Python package. All of the + `tflite_convert` is installed as part of the Python package. All of the examples below use `tflite_convert` for simplicity. * Example: `tflite_convert --output_file=...` -* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow - repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository) - and use `bazel`. This is the recommended approach for converting models that - utilize new features that were not supported by TOCO in TensorFlow 1.9. +* `bazel`: In order to run the latest version of the TensorFlow Lite Converter + either install the nightly build using + [pip](https://www.tensorflow.org/install/pip) or + [clone the TensorFlow repository](https://www.tensorflow.org/install/source) + and use `bazel`. * Example: `bazel run //tensorflow/contrib/lite/python:tflite_convert -- --output_file=...` -### Converting models prior to TensorFlow 1.9. +### Converting models prior to TensorFlow 1.9 -The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python -API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the -`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in -Terminal for additional details on the command-line flags available. There were -no command line tools in TensorFlow 1.8. +The recommended approach for using the converter prior to TensorFlow 1.9 is the +[Python API](python_api.md#pre-tensorflow-1.9). If a command line tool is +desired, the `toco` command line tool was available in TensorFlow 1.7. Enter +`toco --help` in Terminal for additional details on the command-line flags +available. There were no command line tools in TensorFlow 1.8. ## Basic examples @@ -117,9 +93,9 @@ tflite_convert \ ### Convert a TensorFlow GraphDef for quantized inference -TOCO is compatible with fixed point quantization models described -[here](https://www.tensorflow.org/performance/quantization). These are float -models with +The TensorFlow Lite Converter is compatible with fixed point quantization models +described [here](https://www.tensorflow.org/performance/quantization). These are +float models with [`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) ops inserted at the boundaries of fused layers to record min-max range information. This generates a quantized inference workload that reproduces the @@ -141,12 +117,12 @@ tflite_convert \ ### Use \"dummy-quantization\" to try out quantized inference on a float graph -In order to evaluate the possible benefit of generating a quantized graph, TOCO -allows "dummy-quantization" on float graphs. The flags `--default_ranges_min` -and `--default_ranges_max` accept plausible values for the min-max ranges of the -values in all arrays that do not have min-max information. "Dummy-quantization" -will produce lower accuracy but will emulate the performance of a correctly -quantized model. +In order to evaluate the possible benefit of generating a quantized graph, the +converter allows "dummy-quantization" on float graphs. The flags +`--default_ranges_min` and `--default_ranges_max` accept plausible values for +the min-max ranges of the values in all arrays that do not have min-max +information. "Dummy-quantization" will produce lower accuracy but will emulate +the performance of a correctly quantized model. The example below contains a model using Relu6 activation functions. Therefore, a reasonable guess is that most activation ranges should be contained in [0, 6]. @@ -207,10 +183,10 @@ tflite_convert \ ### Specifying subgraphs Any array in the input file can be specified as an input or output array in -order to extract subgraphs out of an input graph file. TOCO discards the parts -of the graph outside of the specific subgraph. Use [graph -visualizations](#graph-visualizations) to identify the input and output arrays -that make up the desired subgraph. +order to extract subgraphs out of an input graph file. The TensorFlow Lite +Converter discards the parts of the graph outside of the specific subgraph. Use +[graph visualizations](#graph-visualizations) to identify the input and output +arrays that make up the desired subgraph. The follow command shows how to extract a single fused layer out of a TensorFlow GraphDef. @@ -247,9 +223,10 @@ function tends to get fused). ## Graph visualizations -TOCO can export a graph to the Graphviz Dot format for easy visualization via -either the `--output_format` flag or the `--dump_graphviz_dir` flag. The -subsections below outline the use cases for each. +The converter can export a graph to the Graphviz Dot format for easy +visualization using either the `--output_format` flag or the +`--dump_graphviz_dir` flag. The subsections below outline the use cases for +each. ### Using `--output_format=GRAPHVIZ_DOT` @@ -323,10 +300,23 @@ As before, these can be rendered to PDFs: dot -Tpdf -O /tmp/toco_*.dot ``` -Sample output files can be seen here: - -* [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) -* [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). +Sample output files can be seen here below. Note that it is the same +`AveragePool` node in the top right of each image. + + + + + + +
+ + + + + + + +
beforeafter
### Graph "video" logging @@ -345,7 +335,7 @@ change was introduced in the graph. * Some typically heavy operators (e.g. Conv) are rendered in a darker red. -* Arrays are octogons with the following colors: +* Arrays are octagons with the following colors: * Constant arrays are blue. * Activation arrays are gray: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_reference.md similarity index 91% rename from tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md rename to tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_reference.md index 00bc8d4ccb8aedcfe701377419e6cd41d0b59855..d65912fea61ccece1db9e184d9e07264598326a2 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_reference.md @@ -1,19 +1,10 @@ -# TensorFlow Lite Optimizing Converter command-line glossary +# TensorFlow Lite Converter command-line glossary -This page is complete reference of command-line flags used by TOCO's command -line starting from TensorFlow 1.9 up until the most recent build of TensorFlow. -It is complemented by the following other documents: +This page is complete reference of command-line flags used by the TensorFlow +Lite Converter's command line starting from TensorFlow 1.9 up until the most +recent build of TensorFlow. -* [README](../README.md) -* [Command-line examples](cmdline_examples.md) -* [Python API examples](python_api.md) - -Table of contents: - -* [High-level flags](#high-level-flags) -* [Model flags](#model-flags) -* [Transformation flags](#transformation-flags) -* [Logging flags](#logging-flags) +[TOC] ## High-level flags @@ -32,7 +23,7 @@ files. The flag `--output_file` is always required. Additionally, either * `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of the output file. Allowed values: * `TFLITE`: TensorFlow Lite FlatBuffer format. - * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the + * `GRAPHVIZ_DOT`: GraphViz `.dot` format containing a visualization of the graph after graph transformations. * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss of TFLite specific transformations. Therefore, the resulting @@ -68,7 +59,7 @@ based on index. * `--input_shapes`. Type: colon-separated list of comma-separated lists of integers. Each comma-separated list of integers gives the shape of one of the input arrays specified in - [TensorFlow convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape). + [TensorFlow convention](https://www.tensorflow.org/guide/dims_types#shape). * Example: `--input_shapes=1,60,80,3` for a typical vision model means a batch size of 1, an input image height of 60, an input image width of 80, and an input image depth of 3 (representing RGB channels). diff --git a/tensorflow/contrib/lite/g3doc/tflite_convert/index.md b/tensorflow/contrib/lite/g3doc/tflite_convert/index.md new file mode 100644 index 0000000000000000000000000000000000000000..12ba0225f62714af3c5198e1ba31890615313c72 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/tflite_convert/index.md @@ -0,0 +1,22 @@ +# TensorFlow Lite Converter + +The TensorFlow Lite Converter converts TensorFlow graphs into +TensorFlow Lite graphs. There are additional usages that are also detailed in +the usage documentation. + + +## Where the converter fits in the TensorFlow landscape + +Once an application developer has a trained TensorFlow model, the TensorFlow +Lite Converter will accept +that model and generate a TensorFlow Lite +[FlatBuffer](https://google.github.io/flatbuffers/) file. The converter currently supports +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. + +![drawing](toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/g3doc/tflite_convert/python_api.md similarity index 65% rename from tensorflow/contrib/lite/toco/g3doc/python_api.md rename to tensorflow/contrib/lite/g3doc/tflite_convert/python_api.md index 51f808d4f07ee33188c34d408c2829aa8bc8f406..e1c0e0c2409b1d74744af8b02fea097043c5e37f 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/g3doc/tflite_convert/python_api.md @@ -1,55 +1,36 @@ -# TensorFlow Lite Optimizing Converter & Interpreter Python API reference - -This page provides examples on how to use TOCO and the TensorFlow Lite -interpreter via the Python API. It is complemented by the following documents: - -* [README](../README.md) -* [Command-line examples](cmdline_examples.md) -* [Command-line glossary](cmdline_reference.md) - -Table of contents: - -* [High-level overview](#high-level-overview) -* [API](#api) -* [Basic examples](#basic) - * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) - * [Exporting a GraphDef from file](#basic-graphdef-file) - * [Exporting a SavedModel](#basic-savedmodel) - * [Exporting a tf.keras File](#basic-keras-file) -* [Complex examples](#complex) - * [Exporting a quantized GraphDef](#complex-quant) -* [TensorFlow Lite Python interpreter](#interpreter) - * [Using the interpreter from a model file](#interpreter-file) - * [Using the interpreter from model data](#interpreter-data) -* [Additional instructions](#additional-instructions) - * [Build from source code](#latest-package) - * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) +# TensorFlow Lite Converter & Interpreter Python API reference + +This page provides examples on how to use the TensorFlow Lite Converter and the +TensorFlow Lite interpreter using the Python API. + +[TOC] + ## High-level overview -While the TensorFlow Lite Optimizing Converter can be used from the command -line, it is often convenient to use it as part of a Python model build and -training script. This is so that conversion can be part of your model -development pipeline. This allows you to know early and often that you are -designing a model that can be targeted to devices with mobile. +While the TensorFlow Lite Converter can be used from the command line, it is +often convenient to use in a Python script as part of the model development +pipeline. This allows you to know early that you are designing a model that can +be targeted to devices with mobile. ## API The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 -is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is -`tf.contrib.lite.Interpreter`. - -`TocoConverter` provides class methods based on the original format of the -model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. -`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. -Example usages for simple float-point models are shown in [Basic -Examples](#basic). Examples usages for more complex models is shown in [Complex -Examples](#complex). - -**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python -interpreter when the conversion fails. This will be remedied as soon as -possible. +is `tf.contrib.lite.TFLiteConverter`. The API for calling the Python intepreter +is `tf.contrib.lite.Interpreter`. + +Note: Reference "Additional Instructions" sections for converting TensorFlow +models to TensorFlow Lite +[in TensorFlow 1.9 to TensorFlow 1.11](#pre-tensorflow-1.11) and +[prior to TensorFlow 1.9](#pre-tensorflow-1.9) + +`TFLiteConverter` provides class methods based on the original format of the +model. `TFLiteConverter.from_session()` is available for GraphDefs. +`TFLiteConverter.from_saved_model()` is available for SavedModels. +`TFLiteConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in +[Basic Examples](#basic). Examples usages for more complex models is shown in +[Complex Examples](#complex). ## Basic examples @@ -71,7 +52,7 @@ out = tf.identity(val, name="out") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -84,7 +65,7 @@ TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and The example uses [Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz). -The function only supports GraphDefs frozen via +The function only supports GraphDefs frozen using [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). ```python @@ -94,7 +75,7 @@ graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" input_arrays = ["input"] output_arrays = ["MobilenetV1/Predictions/Softmax"] -converter = tf.contrib.lite.TocoConverter.from_frozen_graph( +converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) @@ -108,25 +89,26 @@ FlatBuffer. ```python import tensorflow as tf -converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +converter = tf.contrib.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` For more complex SavedModels, the optional parameters that can be passed into -`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`, +`TFLiteConverter.from_saved_model()` are `input_arrays`, `input_shapes`, `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are -available by running `help(tf.contrib.lite.TocoConverter)`. +available by running `help(tf.contrib.lite.TFLiteConverter)`. ### Exporting a tf.keras File The following example shows how to convert a `tf.keras` model into a TensorFlow -Lite FlatBuffer. +Lite FlatBuffer. This example requires +[`h5py`](http://docs.h5py.org/en/latest/build.html) to be installed. ```python import tensorflow as tf -converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file("keras_model.h5") tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -158,7 +140,7 @@ keras_file = "keras_model.h5" tf.keras.models.save_model(model, keras_file) # Convert to TensorFlow Lite model. -converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` @@ -168,7 +150,7 @@ open("converted_model.tflite", "wb").write(tflite_model) For models where the default value of the attributes is not sufficient, the attribute's values should be set before calling `convert()`. In order to call any constants use `tf.contrib.lite.constants.` as seen below with -`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python +`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TFLiteConverter)` in the Python terminal for detailed documentation on the attributes. Although the examples are demonstrated on GraphDefs containing only constants. @@ -188,7 +170,7 @@ val = img + const out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") with tf.Session() as sess: - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 input_arrays = converter.get_input_arrays() converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev @@ -245,7 +227,7 @@ val = img + const out = tf.identity(val, name="out") with tf.Session() as sess: - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [img], [out]) tflite_model = converter.convert() # Load TFLite model and allocate tensors. @@ -257,13 +239,20 @@ interpreter.allocate_tensors() ### Build from source code -In order to run the latest version of the TOCO Python API, clone the TensorFlow -repository, configure the installation, and build and install the pip package. -Detailed instructions are available -[here](https://www.tensorflow.org/install/install_sources). +In order to run the latest version of the TensorFlow Lite Converter Python API, +either install the nightly build with +[pip](https://www.tensorflow.org/install/pip) (recommended) or +[Docker](https://www.tensorflow.org/install/docker), or +[build the pip package from source](https://www.tensorflow.org/install/source). + +### Converting models in TensorFlow 1.9 to TensorFlow 1.11 + +To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.9 through +TensorFlow 1.11, use `TocoConverter`. `TocoConverter` is semantically +identically to `TFLiteConverter`. -### Converting models prior to TensorFlow 1.9. +### Converting models prior to TensorFlow 1.9 -To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert` -function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted -parameters. +To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.7 and TensorFlow +1.8, use the `toco_convert` function. Run `help(tf.contrib.lite.toco_convert)` +to get details about accepted parameters. diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/g3doc/tflite_convert/toco_landscape.svg similarity index 100% rename from tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg rename to tensorflow/contrib/lite/g3doc/tflite_convert/toco_landscape.svg diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index c7cdee07de375c165e01626154d92a81ad880eca..2eb776d10cf8ec68987d13b580eddf2f1bda8e78 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,6 +1,22 @@ - # Building TensorFlow on Android +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ To get you started working with TensorFlow on Android, we'll walk through two ways to build our TensorFlow mobile demos and deploying them on an Android device. The first is Android Studio, which lets you build and deploy in an @@ -93,7 +109,7 @@ requires some knowledge of build systems and Android developer tools, but we'll guide you through the basics here. - First, follow our instructions for - installing from sources. + installing from sources. This will also guide you through installing Bazel and cloning the TensorFlow code. diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md index d003bb2f3855141b51c6d4afc7fc5a46dc08d665..15f0fd396134e40e89266182cb308080d9d250cb 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,10 +1,26 @@ - # Overview +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ TensorFlow was designed to be a good deep learning solution for mobile platforms. Currently we have two solutions for deploying machine learning applications on mobile and embedded devices: TensorFlow for Mobile and -TensorFlow Lite. +TensorFlow Lite. ## TensorFlow Lite versus TensorFlow Mobile diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index be8b4100c89f4b02e651b1585faf438881c9119d..d922907cdc5fe5ccec8864b456586fce0293a0af 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,6 +1,22 @@ - # Building TensorFlow on iOS +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ ## Using CocoaPods The simplest way to get started with TensorFlow on iOS is using the CocoaPods diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index 4d4bb3bc081d613714271f8b0bf7461cb1e0f4d5..fd0e322c93493ed835ae7ec9766a708885c6ac88 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,6 +1,22 @@ - # Integrating TensorFlow libraries +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ Once you have made some progress on a model that addresses the problem you’re trying to solve, it’s important to test it out inside your application immediately. There are often unexpected differences between your training data diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index 7436594fd8580151ba66562eccd408cc7e6c4201..59ff8e774c6c63a01668aee7d6caeea01171468d 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,6 +1,22 @@ - # Optimizing for mobile +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ There are some special issues that you have to deal with when you’re trying to ship on mobile or embedded devices, and you’ll need to think about these as you’re developing your model. diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index d1c67d4c61608bcbc9b0bcee5b60f46a73b44692..1d373251ddf3ba6a0119bd57bf14caf100ef371a 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,6 +1,22 @@ - # Preparing models for mobile deployment +Warning: We expect to deprecate TensorFlow Mobile in early 2019 + +
+

+ TensorFlow Lite is our main mobile and embedded offering. We are + working hard to close the feature gap between TensorFlow Mobile and + TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We + will give ample notice to our users when we get to that point and will + provide help and support to ensure easy migrations. +

+

+ In the meantime, please use TensorFlow Lite. If you have a feature request, + such as a missing op, please post to our GitHub. +

+
+ The requirements for storing model information during training are very different from when you want to release it as part of a mobile app. This section covers the tools involved in converting from a training model to something diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 2657bcd42b364149dd860f420883407231f82636..88e41ffc55d2b666bb4837c12dccb2ebcdcaac33 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -451,16 +451,15 @@ TfLiteStatus Interpreter::AllocateTensors() { // Reset the variable tensors to zero after (re)allocating the tensors. // Developers shouldn't rely on the side effect of this function to reset - // variable tesnsors. They should call `ResetVariableTensorsToZero` directly + // variable tesnsors. They should call `ResetVariableTensors` directly // instead. - ResetVariableTensorsToZero(); + ResetVariableTensors(); return kTfLiteOk; } -// TODO(ycling): Consider to provide other functions to initialize variable -// tensors to non-zero values. -TfLiteStatus Interpreter::ResetVariableTensorsToZero() { +// TODO(ycling): Support non-zero default values. +TfLiteStatus Interpreter::ResetVariableTensors() { for (auto& tensor : tensors_) { if (!tensor.is_variable) { continue; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index aa2bc4def666eee8993c9b3198e864dd6ed642de..651a97e9dc84350569514528ae5635ec040d607f 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -349,6 +349,10 @@ class Interpreter { return context_.allow_fp32_relax_to_fp16; } + // Owning handle to a TfLiteDelegate instance. + using TfLiteDelegatePtr = + std::unique_ptr; + // Allow a delegate to look at the graph and modify the graph to handle // parts of the graph themselves. After this is called, the graph may // contain new nodes that replace 1 more nodes. @@ -421,9 +425,12 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } - // Reset all variable tensors to zero. + // Reset all variable tensors to the default value. + // If a variable tensor doesn't have a buffer, reset it to zero. + // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it + // to the value of the buffer. // WARNING: This is an experimental API and subject to change. - TfLiteStatus ResetVariableTensorsToZero(); + TfLiteStatus ResetVariableTensors(); // Retrieve an operator's description of its work, for profiling purposes. const char* OpProfilingString(const TfLiteRegistration& op_reg, @@ -571,19 +578,11 @@ class Interpreter { TfLiteExternalContextType type, TfLiteExternalContext* ctx); - using TfLiteDelegatePtr = - std::unique_ptr; - // Variant of the public ModifyGraphWithDelegate method that additionally // Assumes ownership of the provided delegate. // WARNING: This is an experimental API and subject to change. - template - TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr typed_delegate, + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate, bool allow_dynamic_tensors = false) { - TfLiteDelegatePtr delegate(typed_delegate.release(), - [](TfLiteDelegate* delegate) { - delete static_cast(delegate); - }); // Note that we retain ownership of the delegate even if graph modification // fails, as delegate use will be in an indeterminate state at that point. owned_delegates_.push_back(std::move(delegate)); @@ -673,6 +672,7 @@ class Interpreter { // List of delegates that have been installed and are owned by this // interpreter instance. Useful if client delegate ownership is burdensome. // WARNING: This is an experimental API and subject to change. + // TODO(b/116667551): Use TfLiteExternalContext for storing state. std::vector owned_delegates_; std::unique_ptr memory_planner_; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index cdede430e29be7b18939f55a8bb06b66f1a3ea33..6c71d5a8d7bb3e275379637b151ab8f998b04f41 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test { template static TfLiteStatus ModifyGraphWithDelegate( Interpreter* interpreter, std::unique_ptr delegate) { - return interpreter->ModifyGraphWithDelegate(std::move(delegate)); + Interpreter::TfLiteDelegatePtr tflite_delegate( + delegate.release(), [](TfLiteDelegate* delegate) { + delete reinterpret_cast(delegate); + }); + return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate)); } protected: diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 098ba7e7731d833678fbd5eab9cce3f022570f23..e68cd26f8124c24cf25d07376b0ebe49da1c149b 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") +JAVA_SRCS = glob([ + "src/main/java/org/tensorflow/lite/*.java", +]) + # Building tensorflow-lite.aar including 4 variants of .so # To build an aar for release, run below command: # bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ @@ -20,28 +24,38 @@ aar_with_jni( android_library = ":tensorflowlite", ) +# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite. +aar_with_jni( + name = "tensorflow-lite-flex", + android_library = ":tensorflowlite_flex", +) + android_library( name = "tensorflowlite", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + ":tensorflowlite_native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite. +android_library( + name = "tensorflowlite_flex", + srcs = JAVA_SRCS, manifest = "AndroidManifest.xml", visibility = ["//visibility:public"], deps = [ - ":tflite_runtime", + ":tensorflowlite_native_flex", "@org_checkerframework_qual", ], ) android_library( name = "tensorflowlite_java", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, visibility = ["//visibility:public"], deps = [ "@org_checkerframework_qual", @@ -50,16 +64,23 @@ android_library( java_library( name = "tensorflowlitelib", - srcs = glob( - [ - "src/main/java/org/tensorflow/lite/*.java", - ], - ), + srcs = JAVA_SRCS, javacopts = JAVACOPTS, visibility = ["//visibility:public"], deps = [ ":libtensorflowlite_jni.so", - "//tensorflow/contrib/lite/java/src/main/native", + "@org_checkerframework_qual", + ], +) + +# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite. +java_library( + name = "tensorflowlitelib_flex", + srcs = JAVA_SRCS, + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_flex_jni.so", "@org_checkerframework_qual", ], ) @@ -72,7 +93,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -87,7 +107,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -110,7 +129,6 @@ java_test( tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", @@ -125,19 +143,37 @@ java_test( data = [ "src/testdata/add.bin", "src/testdata/mobilenet.tflite.bin", + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", ], javacopts = JAVACOPTS, tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ - ":libtensorflowlite_jni.so", ":tensorflowlitelib", "@com_google_truth", "@junit", ], ) +java_test( + name = "InterpreterFlexTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"], + data = [ + "//tensorflow/contrib/lite:testdata/multi_add_flex.bin", + ], + javacopts = JAVACOPTS, + tags = ["no_oss"], + test_class = "org.tensorflow.lite.InterpreterFlexTest", + visibility = ["//visibility:private"], + deps = [ + ":tensorflowlitelib_flex", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", @@ -164,14 +200,29 @@ filegroup( ) cc_library( - name = "tflite_runtime", + name = "tensorflowlite_native", srcs = ["libtensorflowlite_jni.so"], visibility = ["//visibility:public"], ) +cc_library( + name = "tensorflowlite_native_flex", + srcs = ["libtensorflowlite_flex_jni.so"], + visibility = ["//visibility:public"], +) + tflite_jni_binary( name = "libtensorflowlite_jni.so", deps = [ "//tensorflow/contrib/lite/java/src/main/native", ], ) + +# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite. +tflite_jni_binary( + name = "libtensorflowlite_flex_jni.so", + deps = [ + "//tensorflow/contrib/lite/delegates/flex:delegate", + "//tensorflow/contrib/lite/java/src/main/native", + ], +) diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index db837cf29edfc0ffe9950ffedc02cca1389b0fdf..360d622b1bcf5cf379987ceefc43c74b1b6ce5fb 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -3,12 +3,12 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_binary") def aar_with_jni(name, android_library): - # Generate dummy AndroidManifest.xml for dummy apk usage - # (dummy apk is generated by _dummy_app_for_so target below) - native.genrule( - name = name + "_binary_manifest_generator", - outs = [name + "_generated_AndroidManifest.xml"], - cmd = """ + # Generate dummy AndroidManifest.xml for dummy apk usage + # (dummy apk is generated by _dummy_app_for_so target below) + native.genrule( + name = name + "_binary_manifest_generator", + outs = [name + "_generated_AndroidManifest.xml"], + cmd = """ cat > $(OUTS) < $(OUTS) < EOF """, - ) + ) - # Generate dummy apk including .so files and later we extract out - # .so files and throw away the apk. - android_binary( - name = name + "_dummy_app_for_so", - manifest = name + "_generated_AndroidManifest.xml", - custom_package = "dummy.package.for.so", - deps = [android_library], - # In some platforms we don't have an Android SDK/NDK and this target - # can't be built. We need to prevent the build system from trying to - # use the target in that case. - tags = ["manual"], - ) + # Generate dummy apk including .so files and later we extract out + # .so files and throw away the apk. + android_binary( + name = name + "_dummy_app_for_so", + aapt_version = "aapt", + manifest = name + "_generated_AndroidManifest.xml", + custom_package = "dummy.package.for.so", + deps = [android_library], + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = [ + "manual", + "no_cuda_on_cpu_tap", + ], + ) - native.genrule( - name = name, - srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], - outs = [name + ".aar"], - tags = ["manual"], - cmd = """ + native.genrule( + name = name, + srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], + outs = [name + ".aar"], + tags = ["manual"], + cmd = """ cp $(location {}.aar) $(location :{}.aar) chmod +w $(location :{}.aar) origdir=$$PWD @@ -46,4 +50,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*" cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), - ) + ) diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 6a3f0651d03742239e43d178619be8670442b99e..c04b2a61942430108891c612ae410d04d373c840 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -1,4 +1,6 @@ -# TF Lite Android App +# TF Lite Android Image Classifier App Example + +A simple Android example that demonstrates image classification using the camera. ## Building in Android Studio with TensorFlow Lite AAR from JCenter. The build.gradle is configured to use TensorFlow Lite's nightly build. diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index 220d6c2159b56f6349e93132418fa0f6c69d1ab3..5ad738389eb8bc1d875fc888c1336fb3fa140eee 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0 android_binary( name = "TfLiteCameraDemo", srcs = glob(["java/**/*.java"]), + aapt_version = "aapt", assets = [ "//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 4f5662bc2d15f1bf6bfec0b9ec79b09f9e124186..3596e4201150abaecc1cd8fdd736510a0afc97bb 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -58,9 +58,9 @@ import android.view.View; import android.view.ViewGroup; import android.widget.CompoundButton; import android.widget.NumberPicker; -import android.widget.ToggleButton; import android.widget.TextView; import android.widget.Toast; +import android.widget.ToggleButton; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -305,22 +305,24 @@ public class Camera2BasicFragment extends Fragment textView = (TextView) view.findViewById(R.id.text); toggle = (ToggleButton) view.findViewById(R.id.button); - toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { - public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { - classifier.setUseNNAPI(isChecked); - } - }); + toggle.setOnCheckedChangeListener( + new CompoundButton.OnCheckedChangeListener() { + public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { + backgroundHandler.post(() -> classifier.setUseNNAPI(isChecked)); + } + }); np = (NumberPicker) view.findViewById(R.id.np); np.setMinValue(1); np.setMaxValue(10); np.setWrapSelectorWheel(true); - np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() { - @Override - public void onValueChange(NumberPicker picker, int oldVal, int newVal){ - classifier.setNumThreads(newVal); - } - }); + np.setOnValueChangedListener( + new NumberPicker.OnValueChangeListener() { + @Override + public void onValueChange(NumberPicker picker, int oldVal, int newVal) { + backgroundHandler.post(() -> classifier.setNumThreads(newVal)); + } + }); } /** Load the model and labels. */ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index 7bb6afd9d8b77159bb180fad6bbe43ca454f9d14..2d11a57434be98b1b3a7ff398b5ff2ca66df878d 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -59,9 +59,15 @@ public abstract class ImageClassifier { private static final int DIM_PIXEL_SIZE = 3; - /* Preallocated buffers for storing image data in. */ + /** Preallocated buffers for storing image data in. */ private int[] intValues = new int[getImageSizeX() * getImageSizeY()]; + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** The loaded TensorFlow Lite model. */ + private MappedByteBuffer tfliteModel; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ protected Interpreter tflite; @@ -89,7 +95,8 @@ public abstract class ImageClassifier { /** Initializes an {@code ImageClassifier}. */ ImageClassifier(Activity activity) throws IOException { - tflite = new Interpreter(loadModelFile(activity)); + tfliteModel = loadModelFile(activity); + tflite = new Interpreter(tfliteModel, tfliteOptions); labelList = loadLabelList(activity); imgData = ByteBuffer.allocateDirect( @@ -150,20 +157,28 @@ public abstract class ImageClassifier { } } + private void recreateInterpreter() { + if (tflite != null) { + tflite.close(); + tflite = new Interpreter(tfliteModel, tfliteOptions); + } + } + public void setUseNNAPI(Boolean nnapi) { - if (tflite != null) - tflite.setUseNNAPI(nnapi); + tfliteOptions.setUseNNAPI(nnapi); + recreateInterpreter(); } - public void setNumThreads(int num_threads) { - if (tflite != null) - tflite.setNumThreads(num_threads); + public void setNumThreads(int numThreads) { + tfliteOptions.setNumThreads(numThreads); + recreateInterpreter(); } /** Closes tflite to release resources. */ public void close() { tflite.close(); tflite = null; + tfliteModel = null; } /** Reads label list from Assets. */ diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 781289ceb239326ebc6b16cb6dff44ca1771ef72..ea9b9ed4b66a601981f4c402f7f8a4f6749e07fd 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +# Build targets for OVIC classification. java_test( name = "OvicClassifierTest", size = "medium", @@ -44,8 +45,10 @@ java_binary( android_library( name = "ovicbenchmarkerlib", srcs = [ + "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java", + "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java", "src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + "src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java", ], manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", tags = ["no_oss"], @@ -59,8 +62,8 @@ android_library( java_library( name = "ovicbenchmarkerlib_java", srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java", "src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], javacopts = JAVACOPTS, tags = ["no_oss"], @@ -72,3 +75,58 @@ java_library( "@org_checkerframework_qual", ], ) + +# Build targets for OVIC detection. +java_test( + name = "OvicDetectorTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicDetectorTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "@tflite_mobilenet_ssd_quant//:detect.tflite", + ], + javacopts = JAVACOPTS, + tags = ["no_oss"], + test_class = "org.tensorflow.ovic.OvicDetectorTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +android_library( + name = "ovicdetectionbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/BoundingBox.java", + "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java", + "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java", + "src/main/java/org/tensorflow/ovic/OvicDetector.java", + "src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicdetectionbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/BoundingBox.java", + "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java", + "src/main/java/org/tensorflow/ovic/OvicDetector.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 26349347faebac135ae555e0c5d8219046ab1c29..df77bfaab3251c0ebe2e377e84d11965fdb821dd 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -4,7 +4,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco ## Pre-requisite -Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. +Follow the steps [here](https://www.tensorflow.org/lite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. ## Test the benchmarker: diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index a8d751ade26adc358e130138381eab9956f2d848..f567358ea33966ea8fdb422749662e22111c5fcc 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -6,12 +6,14 @@ licenses(["notice"]) # Apache 2.0 android_binary( name = "ovic_benchmarker_binary", srcs = [ - "OvicBenchmarker.java", "OvicBenchmarkerActivity.java", ], + aapt_version = "aapt", assets = [ - "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt", "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "@tflite_mobilenet_ssd_quant//:detect.tflite", ], assets_dir = "", custom_package = "ovic.demo.app", @@ -25,6 +27,7 @@ android_binary( deps = [ "//tensorflow/contrib/lite/java:tensorflowlite", "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", + "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib", "@androidsdk//com.android.support:support-v13-25.2.0", "@androidsdk//com.android.support:support-v4-25.2.0", ], diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java index 59457c308ad7caa17c52563f6a70df79e8a17914..48c29ecebeed42ac9a2e0bc801cab1fb1f9201e8 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java @@ -34,18 +34,19 @@ import java.io.InputStream; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.text.DecimalFormat; -import org.tensorflow.ovic.OvicSingleImageResult; +import org.tensorflow.ovic.OvicBenchmarker; +import org.tensorflow.ovic.OvicClassifierBenchmarker; +import org.tensorflow.ovic.OvicDetectorBenchmarker; /** Class that benchmark image classifier models. */ public class OvicBenchmarkerActivity extends Activity { /** Tag for the {@link Log}. */ private static final String TAG = "OvicBenchmarkerActivity"; - /** Name of the label file stored in Assets. */ - private static final String LABEL_PATH = "labels.txt"; - - private static final String TEST_IMAGE_PATH = "test_image_224.jpg"; - private static final String MODEL_PATH = "float_model.lite"; + /** Name of the task-dependent data files stored in Assets. */ + private static String labelPath = null; + private static String testImagePath = null; + private static String modelPath = null; /** * Each bottom press will launch a benchmarking experiment. The experiment stops when either the * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS, @@ -64,8 +65,6 @@ public class OvicBenchmarkerActivity extends Activity { private MappedByteBuffer model = null; private InputStream labelInputStream = null; private OvicBenchmarker benchmarker; - /** Inference result of each iteration. */ - OvicSingleImageResult iterResult = null; private TextView textView = null; // private Button startButton = null; @@ -81,21 +80,31 @@ public class OvicBenchmarkerActivity extends Activity { } private Bitmap loadTestBitmap() throws IOException { - InputStream imageStream = getAssets().open(TEST_IMAGE_PATH); + InputStream imageStream = getAssets().open(testImagePath); return BitmapFactory.decodeStream(imageStream); } - public void initializeTest() throws IOException { + public void initializeTest(boolean benchmarkClassification) throws IOException { Log.i(TAG, "Initializing benchmarker."); - benchmarker = new OvicBenchmarker(WALL_TIME); + if (benchmarkClassification) { + benchmarker = new OvicClassifierBenchmarker(WALL_TIME); + labelPath = "labels.txt"; + testImagePath = "test_image_224.jpg"; + modelPath = "quantized_model.lite"; + } else { // Benchmarking detection. + benchmarker = new OvicDetectorBenchmarker(WALL_TIME); + labelPath = "coco_labels.txt"; + testImagePath = "test_image_224.jpg"; + modelPath = "detect.tflite"; + } AssetManager am = getAssets(); - AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH); + AssetFileDescriptor fileDescriptor = am.openFd(modelPath); FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = modelInputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - labelInputStream = am.open(LABEL_PATH); + labelInputStream = am.open(labelPath); } public Boolean doTestIteration() throws IOException, InterruptedException { @@ -115,24 +124,44 @@ public class OvicBenchmarkerActivity extends Activity { Log.i(TAG, "Going to do test iter."); // Start testing. Bitmap testImageBitmap = loadTestBitmap(); - iterResult = benchmarker.doTestIteration(testImageBitmap); - testImageBitmap.recycle(); - if (iterResult == null) { + try { + if (!benchmarker.processBitmap(testImageBitmap)) { + throw new RuntimeException("Failed to run test."); + } + } catch (Exception e) { + e.printStackTrace(); + throw e; + } finally { + testImageBitmap.recycle(); + } + String iterResultString = benchmarker.getLastResultString(); + if (iterResultString == null) { throw new RuntimeException("Inference failed to produce a result."); } - Log.i(TAG, iterResult.toString()); + Log.i(TAG, iterResultString); return true; } - public void startPressed(View view) throws IOException { - Log.i(TAG, "Start pressed"); + public void detectPressed(View view) throws IOException { + benchmarkSession(false); + } + public void classifyPressed(View view) throws IOException { + benchmarkSession(true); + } + + private void benchmarkSession(boolean benchmarkClassification) throws IOException { try { - initializeTest(); + initializeTest(benchmarkClassification); } catch (IOException e) { Log.e(TAG, "Can't initialize benchmarker.", e); throw e; } String displayText = ""; + if (benchmarkClassification) { + displayText = "Classification benchmark: "; + } else { + displayText = "Detection benchmark: "; + } try { setProcessorAffinity(BIG_CORE_MASK); } catch (IOException e) { @@ -142,7 +171,6 @@ public class OvicBenchmarkerActivity extends Activity { Log.i(TAG, "Successfully initialized benchmarker."); int testIter = 0; Boolean iterSuccess = false; - double totalLatency = 0.0f; while (testIter < MAX_ITERATIONS) { try { iterSuccess = doTestIteration(); @@ -151,23 +179,22 @@ public class OvicBenchmarkerActivity extends Activity { throw e; } catch (InterruptedException e) { Log.e(TAG, "Interrupted at iteration " + testIter); + displayText += e.getMessage() + "\n"; } if (!iterSuccess) { break; } testIter++; - totalLatency += (double) iterResult.latency; } - ; Log.i(TAG, "Benchmarking finished"); if (textView != null) { if (testIter > 0) { textView.setText( displayText - + MODEL_PATH + + modelPath + ": Average latency=" - + df2.format(totalLatency / testIter) + + df2.format(benchmarker.getTotalRunTime() / testIter) + "ms after " + testIter + " runs."); diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml index e9d83bae543ae62ba8749c4c91b36b20bf09a176..1bce60ff7def2b0df9c93a4106a9aafff0009a2f 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -30,14 +30,14 @@ android:layout_height="wrap_content" android:text="@string/initial_status_msg" android:id="@+id/textView" - android:layout_above="@+id/button_start" + android:layout_above="@+id/button_clf_start" android:layout_alignParentTop="true"/>