diff --git a/README.md b/README.md index 3cdb6e478ddf4f18af7f81bb3e321510903beb9d..c66f7e3f3f49ed90e4e75475185585a932049f37 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture lets you deploy computation to one +between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting code. TensorFlow also includes TensorBoard, a data visualization toolkit. @@ -86,6 +86,7 @@ The TensorFlow project strives to abide by generally accepted best practices in * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) diff --git a/RELEASE.md b/RELEASE.md index 6f54dee58f75c29a16545ba25de12fe059baf1eb..e8459531748628fd822d876d79625fdd65798791 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,74 @@ +# Release 1.7.0 + +## Major Features And Improvements +* Eager mode is moving out of contrib, try `tf.enable_eager_execution()`. +* Graph rewrites emulating fixed-point quantization compatible with TensorFlow Lite, supported by new `tf.contrib.quantize` package. +* Easily customize gradient computation with `tf.custom_gradient`. +* [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md), the graphical user interface (GUI) of TensorFlow Debugger (tfdbg), is now in alpha. +* Experimental support for reading a sqlite database as a `Dataset` with new `tf.contrib.data.SqlDataset`. +* Distributed Mutex / CriticalSection added to `tf.contrib.framework.CriticalSection`. +* Better text processing with `tf.regex_replace`. +* Easy, efficient sequence input with `tf.contrib.data.bucket_by_sequence_length` +* Initial support for `tf.contrib.tensorrt` that enables native TensorRT in + TensorFlow. + +## Bug Fixes and Other Changes +* Accelerated Linear Algebra (XLA): + * Add `MaxPoolGradGrad` support for XLA + * CSE pass from Tensorflow is now disabled in XLA. +* `tf.data`: + * `tf.data.Dataset` + * Add support for building C++ Dataset op kernels as external libraries, using the `tf.load_op_library()` mechanism. + * `Dataset.list_files()` now shuffles its output by default. + * `Dataset.shuffle(..., seed=tf.constant(0, dtype=tf.int64))` now yields the same sequence of elements as `Dataset.shuffle(..., seed=0)`. + * Add `num_parallel_reads` argument to `tf.data.TFRecordDataset`. +* `tf.contrib`: + * `tf.contrib.bayesflow.halton_sequence` now supports randomization. + * Add support for scalars in `tf.contrib.all_reduce`. + * Add `effective_sample_size` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `potential_scale_reduction` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `BatchNormalization`, `Kumaraswamy` bijectors. + * Deprecate `tf.contrib.learn`. Please check contrib/learn/README.md for instructions on how to convert existing code. + * `tf.contrib.data` + * Remove deprecated `tf.contrib.data.Dataset`, `tf.contrib.data.Iterator`, `tf.contrib.data.FixedLengthRecordDataset`, `tf.contrib.data.TextLineDataset`, and `tf.contrib.data.TFRecordDataset` classes. + * Added `bucket_by_sequence_length`, `sliding_window_batch`, and `make_batched_features_dataset` + * Remove unmaintained `tf.contrib.ndlstm`. You can find it externally at https://github.com/tmbarchive/tfndlstm. + * Moved most of `tf.contrib.bayesflow` to its own repo: `tfp` +* Other: + * tf.py_func now reports the full stack trace if an exception occurs. + * Integrate `TPUClusterResolver` with GKE's integration for Cloud TPUs. + * Add a library for statistical testing of samplers. + * Add Helpers to stream data from the GCE VM to a Cloud TPU. + * Integrate ClusterResolvers with TPUEstimator. + * Unify metropolis_hastings interface with HMC kernel. + * Move LIBXSMM convolutions to a separate --define flag so that they are disabled by default. + * Fix `MomentumOptimizer` lambda. + * Reduce `tfp.layers` boilerplate via programmable docstrings. + * Add `auc_with_confidence_intervals`, a method for computing the AUC and confidence interval with linearithmic time complexity. + * `regression_head` now accepts customized link function, to satisfy the usage that user can define their own link function if the `array_ops.identity` does not meet the requirement. + * Fix `initialized_value` and `initial_value` behaviors for `ResourceVariables` created from `VariableDef` protos. + * Add TensorSpec to represent the specification of Tensors. + * Constant folding pass is now deterministic. + * Support `float16` `dtype` in `tf.linalg.*`. + * Add `tf.estimator.export.TensorServingInputReceiver` that allows `tf.estimator.Estimator.export_savedmodel` to pass raw tensors to model functions. + +## Deprecations + +* TensorFlow 1.7 may be the last time we support Cuda versions below 8.0. + Starting with TensorFlow 1.8 release, 8.0 will be the minimum supported + version. +* TensorFlow 1.7 may be the last time we support cuDNN versions below 6.0. + Starting with TensorFlow 1.8 release, 6.0 will be the minimum supported + version. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abe, Alistair Low, Andy Kernahan, Appledore, Ben, Ben Barsdell, Boris Pfahringer, Brad Wannow, Brett Koonce, Carl Thomé, cclauss, Chengzhi Chen, Chris Drake, Christopher Yeh, Clayne Robison, Codrut Grosu, Daniel Trebbien, Danny Goodman, David Goodwin, David Norman, Deron Eriksson, Donggeon Lim, Donny Viszneki, DosLin, DylanDmitri, Francisco Guerrero, Fred Reiss, gdh1995, Giuseppe, Glenn Weidner, gracehoney, Guozhong Zhuang, Haichen "Hc" Li, Harald Husum, harumitsu.nobuta, Henry Spivey, hsm207, Jekyll Song, Jerome, Jiongyan Zhang, jjsjann123, John Sungjin Park, Johnson145, JoshVarty, Julian Wolff, Jun Wang, June-One, Kamil Sindi, Kb Sriram, Kdavis-Mozilla, Kenji, lazypanda1, Liang-Chi Hsieh, Loo Rong Jie, Mahesh Bhosale, MandarJKulkarni, ManHyuk, Marcus Ong, Marshal Hayes, Martin Pool, matthieudelaro, mdfaijul, mholzel, Michael Zhou, Ming Li, Minmin Sun, Myungjoo Ham, MyungsungKwak, Naman Kamra, Peng Yu, Penghao Cen, Phil, Raghuraman-K, resec, Rohin Mohanadas, Sandeep N Gupta, Scott Tseng, seaotterman, Seo Sanghyeon, Sergei Lebedev, Ted Chang, terrytangyuan, Tim H, tkunic, Tod, vihanjain, Yan Facai (颜发才), Yin Li, Yong Tang, Yukun Chen, Yusuke Yamada + + + # Release 1.6.0 ## Breaking Changes diff --git a/SECURITY.md b/SECURITY.md index 378e77696725e338e8289cda84dbc543303ae053..a5ce3a62ee202f6e7d83f0fedc2777d9c88ba9b5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,7 +168,18 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. + +In addition, please include the following information along with your report: + +* Your name and affiliation (if any). +* A description the technical details of the vulnerabilities. It is very + important to let us know how we can reproduce your findings. +* An explanation who can exploit this vulnerability, and what they gain when + doing so -- write an attack scenario. This will help us evaluate your report + quickly, especially if the issue is complex. +* Whether this vulnerability public or known to third parties. If it is, please + provide details. If you believe that an existing (public) issue is security-related, please send an email to `security@tensorflow.org`. The email should include the issue ID and @@ -233,7 +244,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|-------------------|:-----------------:|--------------------|-----------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|--------------------|:-----------------:|-----------------------|-----------------------------| +| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/WORKSPACE b/WORKSPACE index 1e38a9a8cd754886fc5232531816b875de0879a3..11c5cdb2070e79b16540a39f13cab28608962340 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,12 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +# We must check the bazel version before trying to parse any other BUILD +# files, in case the parsing of those build files depends on the bazel +# version we require here. +load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") +check_bazel_version_at_least("0.10.0") + load("//tensorflow:workspace.bzl", "tf_workspace") # Uncomment and update the paths in these entries to build the Android demo. diff --git a/configure.py b/configure.py index d14edef1be9e31137c96bed7aebf7ba158b3274f..81d5ad77ee48b101c2f55baf5b3ee935dab756c8 100644 --- a/configure.py +++ b/configure.py @@ -35,12 +35,13 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' +_DEFAULT_NCCL_VERSION = '1.3' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -484,6 +485,8 @@ def set_cc_opt_flags(environ_cp): if is_ppc64le(): # gcc on ppc64le does not support -march, use mcpu instead default_cc_opt_flags = '-mcpu=native' + elif is_windows(): + default_cc_opt_flags = '/arch:AVX' else: default_cc_opt_flags = '-march=native' question = ('Please specify optimization flags to use during compilation when' @@ -494,7 +497,7 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - if not is_ppc64le(): + if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get @@ -502,7 +505,6 @@ def set_cc_opt_flags(environ_cp): write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -524,7 +526,7 @@ def set_tf_cuda_clang(environ_cp): def set_tf_download_clang(environ_cp): """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you want to download a fresh release of clang? (Experimental)' + question = 'Do you wish to download a fresh release of clang? (Experimental)' yes_reply = 'Clang will be downloaded and used to compile tensorflow.' no_reply = 'Clang will not be downloaded.' set_action_env_var( @@ -1103,6 +1105,81 @@ def set_tf_tensorrt_install_path(environ_cp): write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) +def set_tf_nccl_install_path(environ_cp): + """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently NCCL is only supported on Linux platforms.') + + ask_nccl_version = ( + 'Please specify the NCCL version you want to use. ' + '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + tf_nccl_version = get_from_env_or_user_or_default( + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + + if tf_nccl_version == '1': + break # No need to get install path, NCCL 1 is a GitHub repo. + + # TODO(csigg): Look with ldconfig first if we can find the library in paths + # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding + # include directory. This is where the NCCL .deb packages install them. + # Then ask the user if we should use that. Instead of a single + # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to + # nccl_configure.bzl + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = (r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + break + + # Reset and Retry + print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, + nccl_hdr_path)) + + environ_cp['TF_NCCL_VERSION'] = '' + else: + raise UserInputError('Invalid TF_NCCL setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TF_NCCL_VERSION + environ_cp['TF_NCCL_VERSION'] = tf_nccl_version + write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1397,6 +1474,9 @@ def main(): environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1411,7 +1491,7 @@ def main(): set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', False, 'kafka') + 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1436,6 +1516,7 @@ def main(): set_tf_cudnn_version(environ_cp) if is_linux(): set_tf_tensorrt_install_path(environ_cp) + set_tf_nccl_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 'LD_LIBRARY_PATH') != '1': @@ -1444,16 +1525,8 @@ def main(): set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - if not is_windows(): - # Ask if we want to download clang release while building. - set_tf_download_clang(environ_cp) - else: - # We use bazel's generated crosstool on Windows and there is no - # way to provide downloaded toolchain for that yet. - # TODO(ibiryukov): Investigate using clang as a cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - + # Ask whether we should download the clang toolchain. + set_tf_download_clang(environ_cp) if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) @@ -1463,6 +1536,13 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) + else: + # CUDA not required. Ask whether we should download the clang toolchain and + # use it for the CPU build. + set_tf_download_clang(environ_cp) + if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': + write_to_bazelrc('build --config=download_clang') + write_to_bazelrc('test --config=download_clang') set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9932e5607685b5b8f5900bdfa42363151e57d3f1..823393ebdf1f4b658361f31963a275a683e61002 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -240,6 +240,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support_windows_override", + define_values = {"with_kafka_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -394,19 +401,6 @@ package_group( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "tensorflow_py", srcs = ["__init__.py"], @@ -426,289 +420,6 @@ py_library( ], ) -filegroup( - name = "all_opensource_files", - data = [ - ":all_files", - "//tensorflow/c:all_files", - "//tensorflow/cc:all_files", - "//tensorflow/cc/saved_model:all_files", - "//tensorflow/cc/saved_model/python:all_files", - "//tensorflow/cc/tools:all_files", - "//tensorflow/compiler/aot:all_files", - "//tensorflow/compiler/aot/tests:all_files", - "//tensorflow/compiler/jit:all_files", - "//tensorflow/compiler/jit/graphcycles:all_files", - "//tensorflow/compiler/jit/kernels:all_files", - "//tensorflow/compiler/jit/legacy_flags:all_files", - "//tensorflow/compiler/jit/ops:all_files", - "//tensorflow/compiler/plugin:all_files", - "//tensorflow/compiler/tests:all_files", - "//tensorflow/compiler/tf2xla:all_files", - "//tensorflow/compiler/tf2xla/cc:all_files", - "//tensorflow/compiler/tf2xla/kernels:all_files", - "//tensorflow/compiler/tf2xla/lib:all_files", - "//tensorflow/compiler/tf2xla/ops:all_files", - "//tensorflow/compiler/xla:all_files", - "//tensorflow/compiler/xla/client:all_files", - "//tensorflow/compiler/xla/client/lib:all_files", - "//tensorflow/compiler/xla/client/xla_client:all_files", - "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/python:all_files", - "//tensorflow/compiler/xla/service:all_files", - "//tensorflow/compiler/xla/service/cpu:all_files", - "//tensorflow/compiler/xla/service/gpu:all_files", - "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", - "//tensorflow/compiler/xla/service/interpreter:all_files", - "//tensorflow/compiler/xla/service/llvm_ir:all_files", - "//tensorflow/compiler/xla/tests:all_files", - "//tensorflow/compiler/xla/tools:all_files", - "//tensorflow/compiler/xla/tools/parser:all_files", - "//tensorflow/contrib:all_files", - "//tensorflow/contrib/all_reduce:all_files", - "//tensorflow/contrib/android:all_files", - "//tensorflow/contrib/batching:all_files", - "//tensorflow/contrib/bayesflow:all_files", - "//tensorflow/contrib/boosted_trees:all_files", - "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", - "//tensorflow/contrib/boosted_trees/lib:all_files", - "//tensorflow/contrib/boosted_trees/proto:all_files", - "//tensorflow/contrib/boosted_trees/resources:all_files", - "//tensorflow/contrib/cloud:all_files", - "//tensorflow/contrib/cloud/kernels:all_files", - "//tensorflow/contrib/cluster_resolver:all_files", - "//tensorflow/contrib/coder:all_files", - "//tensorflow/contrib/compiler:all_files", - "//tensorflow/contrib/copy_graph:all_files", - "//tensorflow/contrib/crf:all_files", - "//tensorflow/contrib/cudnn_rnn:all_files", - "//tensorflow/contrib/data:all_files", - "//tensorflow/contrib/data/kernels:all_files", - "//tensorflow/contrib/data/python/kernel_tests:all_files", - "//tensorflow/contrib/data/python/ops:all_files", - "//tensorflow/contrib/decision_trees/proto:all_files", - "//tensorflow/contrib/deprecated:all_files", - "//tensorflow/contrib/distributions:all_files", - "//tensorflow/contrib/eager/proto:all_files", - "//tensorflow/contrib/eager/python:all_files", - "//tensorflow/contrib/estimator:all_files", - "//tensorflow/contrib/factorization:all_files", - "//tensorflow/contrib/factorization/examples:all_files", - "//tensorflow/contrib/factorization/kernels:all_files", - "//tensorflow/contrib/feature_column:all_files", - "//tensorflow/contrib/ffmpeg:all_files", - "//tensorflow/contrib/ffmpeg/default:all_files", - "//tensorflow/contrib/framework:all_files", - "//tensorflow/contrib/fused_conv:all_files", - "//tensorflow/contrib/gan:all_files", - "//tensorflow/contrib/gdr:all_files", - "//tensorflow/contrib/graph_editor:all_files", - "//tensorflow/contrib/grid_rnn:all_files", - "//tensorflow/contrib/hooks:all_files", - "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files", - "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", - "//tensorflow/contrib/image:all_files", - "//tensorflow/contrib/input_pipeline:all_files", - "//tensorflow/contrib/input_pipeline/kernels:all_files", - "//tensorflow/contrib/integrate:all_files", - "//tensorflow/contrib/keras:all_files", - "//tensorflow/contrib/kernel_methods:all_files", - "//tensorflow/contrib/kfac:all_files", - "//tensorflow/contrib/kfac/examples:all_files", - "//tensorflow/contrib/kfac/examples/tests:all_files", - "//tensorflow/contrib/kfac/python/kernel_tests:all_files", - "//tensorflow/contrib/kfac/python/ops:all_files", - "//tensorflow/contrib/labeled_tensor:all_files", - "//tensorflow/contrib/layers:all_files", - "//tensorflow/contrib/layers/kernels:all_files", - "//tensorflow/contrib/learn:all_files", - "//tensorflow/contrib/learn/python/learn/datasets:all_files", - "//tensorflow/contrib/legacy_seq2seq:all_files", - "//tensorflow/contrib/libsvm:all_files", - "//tensorflow/contrib/linalg:all_files", - "//tensorflow/contrib/linear_optimizer:all_files", - "//tensorflow/contrib/lite:all_files", - "//tensorflow/contrib/lite/java:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", - "//tensorflow/contrib/lite/java/src/main/native:all_files", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", - "//tensorflow/contrib/lite/kernels:all_files", - "//tensorflow/contrib/lite/kernels/internal:all_files", - "//tensorflow/contrib/lite/models/smartreply:all_files", - "//tensorflow/contrib/lite/nnapi:all_files", - "//tensorflow/contrib/lite/python:all_files", - "//tensorflow/contrib/lite/schema:all_files", - "//tensorflow/contrib/lite/testing:all_files", - "//tensorflow/contrib/lite/toco:all_files", - "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", - "//tensorflow/contrib/lite/toco/python:all_files", - "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", - "//tensorflow/contrib/lite/toco/tflite:all_files", - "//tensorflow/contrib/lite/tools:all_files", - "//tensorflow/contrib/lookup:all_files", - "//tensorflow/contrib/losses:all_files", - "//tensorflow/contrib/makefile:all_files", - "//tensorflow/contrib/memory_stats:all_files", - "//tensorflow/contrib/meta_graph_transform:all_files", - "//tensorflow/contrib/metrics:all_files", - "//tensorflow/contrib/model_pruning:all_files", - "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", - "//tensorflow/contrib/nccl:all_files", - "//tensorflow/contrib/nearest_neighbor:all_files", - "//tensorflow/contrib/nn:all_files", - "//tensorflow/contrib/opt:all_files", - "//tensorflow/contrib/periodic_resample:all_files", - "//tensorflow/contrib/predictor:all_files", - "//tensorflow/contrib/py2tf:all_files", - "//tensorflow/contrib/py2tf/converters:all_files", - "//tensorflow/contrib/py2tf/impl:all_files", - "//tensorflow/contrib/py2tf/pyct:all_files", - "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", - "//tensorflow/contrib/py2tf/utils:all_files", - "//tensorflow/contrib/quantize:all_files", - "//tensorflow/contrib/receptive_field:all_files", - "//tensorflow/contrib/reduce_slice_ops:all_files", - "//tensorflow/contrib/remote_fused_graph/pylib:all_files", - "//tensorflow/contrib/resampler:all_files", - "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/saved_model:all_files", - "//tensorflow/contrib/saved_model/cc/saved_model:all_files", - "//tensorflow/contrib/seq2seq:all_files", - "//tensorflow/contrib/session_bundle:all_files", - "//tensorflow/contrib/session_bundle/example:all_files", - "//tensorflow/contrib/signal:all_files", - "//tensorflow/contrib/slim:all_files", - "//tensorflow/contrib/slim/python/slim/data:all_files", - "//tensorflow/contrib/slim/python/slim/nets:all_files", - "//tensorflow/contrib/solvers:all_files", - "//tensorflow/contrib/sparsemax:all_files", - "//tensorflow/contrib/specs:all_files", - "//tensorflow/contrib/staging:all_files", - "//tensorflow/contrib/stat_summarizer:all_files", - "//tensorflow/contrib/stateless:all_files", - "//tensorflow/contrib/summary:all_files", - "//tensorflow/contrib/tensor_forest:all_files", - "//tensorflow/contrib/tensor_forest/hybrid:all_files", - "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", - "//tensorflow/contrib/tensor_forest/proto:all_files", - "//tensorflow/contrib/tensorboard:all_files", - "//tensorflow/contrib/tensorboard/db:all_files", - "//tensorflow/contrib/tensorrt:all_files", - "//tensorflow/contrib/testing:all_files", - "//tensorflow/contrib/text:all_files", - "//tensorflow/contrib/tfprof:all_files", - "//tensorflow/contrib/timeseries:all_files", - "//tensorflow/contrib/timeseries/examples:all_files", - "//tensorflow/contrib/timeseries/python/timeseries:all_files", - "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", - "//tensorflow/contrib/tpu:all_files", - "//tensorflow/contrib/tpu/profiler:all_files", - "//tensorflow/contrib/tpu/proto:all_files", - "//tensorflow/contrib/training:all_files", - "//tensorflow/contrib/util:all_files", - "//tensorflow/contrib/verbs:all_files", - "//tensorflow/core:all_files", - "//tensorflow/core/api_def:all_files", - "//tensorflow/core/debug:all_files", - "//tensorflow/core/distributed_runtime:all_files", - "//tensorflow/core/distributed_runtime/rpc:all_files", - "//tensorflow/core/grappler:all_files", - "//tensorflow/core/grappler/clusters:all_files", - "//tensorflow/core/grappler/costs:all_files", - "//tensorflow/core/grappler/inputs:all_files", - "//tensorflow/core/grappler/optimizers:all_files", - "//tensorflow/core/grappler/utils:all_files", - "//tensorflow/core/kernels:all_files", - "//tensorflow/core/kernels/batching_util:all_files", - "//tensorflow/core/kernels/data:all_files", - "//tensorflow/core/kernels/data/sql:all_files", - "//tensorflow/core/kernels/fuzzing:all_files", - "//tensorflow/core/kernels/hexagon:all_files", - "//tensorflow/core/kernels/neon:all_files", - "//tensorflow/core/lib/db:all_files", - "//tensorflow/core/ops/compat:all_files", - "//tensorflow/core/platform/cloud:all_files", - "//tensorflow/core/platform/default/build_config:all_files", - "//tensorflow/core/platform/hadoop:all_files", - "//tensorflow/core/platform/s3:all_files", - "//tensorflow/core/profiler:all_files", - "//tensorflow/core/profiler/internal:all_files", - "//tensorflow/core/profiler/internal/advisor:all_files", - "//tensorflow/core/util/ctc:all_files", - "//tensorflow/core/util/tensor_bundle:all_files", - "//tensorflow/examples/adding_an_op:all_files", - "//tensorflow/examples/android:all_files", - "//tensorflow/examples/benchmark:all_files", - "//tensorflow/examples/get_started/regression:all_files", - "//tensorflow/examples/how_tos/reading_data:all_files", - "//tensorflow/examples/image_retraining:all_files", - "//tensorflow/examples/label_image:all_files", - "//tensorflow/examples/learn:all_files", - "//tensorflow/examples/multibox_detector:all_files", - "//tensorflow/examples/saved_model:all_files", - "//tensorflow/examples/speech_commands:all_files", - "//tensorflow/examples/tutorials/estimators:all_files", - "//tensorflow/examples/tutorials/layers:all_files", - "//tensorflow/examples/tutorials/mnist:all_files", - "//tensorflow/examples/tutorials/monitors:all_files", - "//tensorflow/examples/tutorials/word2vec:all_files", - "//tensorflow/examples/wav_to_spectrogram:all_files", - "//tensorflow/go:all_files", - "//tensorflow/java:all_files", - "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", - "//tensorflow/java/src/main/native:all_files", - "//tensorflow/python:all_files", - "//tensorflow/python/data:all_files", - "//tensorflow/python/data/kernel_tests:all_files", - "//tensorflow/python/data/ops:all_files", - "//tensorflow/python/data/util:all_files", - "//tensorflow/python/debug:all_files", - "//tensorflow/python/eager:all_files", - "//tensorflow/python/estimator:all_files", - "//tensorflow/python/feature_column:all_files", - "//tensorflow/python/keras:all_files", - "//tensorflow/python/kernel_tests:all_files", - "//tensorflow/python/kernel_tests/distributions:all_files", - "//tensorflow/python/kernel_tests/linalg:all_files", - "//tensorflow/python/kernel_tests/random:all_files", - "//tensorflow/python/ops/distributions:all_files", - "//tensorflow/python/ops/linalg:all_files", - "//tensorflow/python/ops/losses:all_files", - "//tensorflow/python/profiler:all_files", - "//tensorflow/python/profiler/internal:all_files", - "//tensorflow/python/saved_model:all_files", - "//tensorflow/python/tools:all_files", - "//tensorflow/tools/api/generator:all_files", - "//tensorflow/tools/api/golden:all_files", - "//tensorflow/tools/api/lib:all_files", - "//tensorflow/tools/api/tests:all_files", - "//tensorflow/tools/benchmark:all_files", - "//tensorflow/tools/build_info:all_files", - "//tensorflow/tools/ci_build/gpu_build:all_files", - "//tensorflow/tools/common:all_files", - "//tensorflow/tools/compatibility:all_files", - "//tensorflow/tools/dist_test/server:all_files", - "//tensorflow/tools/docker:all_files", - "//tensorflow/tools/docker/notebooks:all_files", - "//tensorflow/tools/docs:all_files", - "//tensorflow/tools/git:all_files", - "//tensorflow/tools/graph_transforms:all_files", - "//tensorflow/tools/mlpbtxt:all_files", - "//tensorflow/tools/proto_text:all_files", - "//tensorflow/tools/quantization:all_files", - "//tensorflow/tools/test:all_files", - "//tensorflow/user_ops:all_files", - "//third_party/eigen3:all_files", - "//third_party/fft2d:all_files", - "//third_party/flatbuffers:all_files", - "//third_party/hadoop:all_files", - "//third_party/sycl:all_files", - "//third_party/sycl/sycl:all_files", - ], - visibility = ["//visibility:public"], -) - load( "//third_party/mkl:build_defs.bzl", "if_mkl", @@ -785,7 +496,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow/c:exported_symbols.lds", + "$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], @@ -794,7 +505,7 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow/c:version_script.lds", + "$(location //tensorflow/c:version_script.lds)", ], }), deps = [ @@ -812,7 +523,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow:tf_exported_symbols.lds", + "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], @@ -820,7 +531,7 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow:tf_version_script.lds", + "$(location //tensorflow:tf_version_script.lds)", ], }), deps = [ diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 29ed957c9aa8cbe515f5f43bdccbf8c94f47c459..2367014cd02c721ea96581919c3efc96e772d9a6 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -34,6 +34,8 @@ filegroup( exclude = [ "c_api_experimental.cc", "c_api_experimental.h", + "python_api.cc", + "python_api.h", "*test*", ], ), @@ -116,6 +118,10 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/contrib/tpu:all_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) @@ -212,6 +218,27 @@ tf_cuda_cc_test( ], ) +tf_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["testdata/tf_record"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api_experimental", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "c_api_function_test", size = "small", @@ -256,20 +283,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + # TODO(b/74620627): remove when _USE_C_SHAPES is removed + "//tensorflow/python:cpp_shape_inference_proto_cc", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 778cb667e2c0015c6a768ecf3b12b82601764117..18eeb2816807ec9986999cfc2c9a4c0f032683c0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -647,11 +647,11 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, for (auto it : graph->sessions) { mutex_lock session_lock(it.first->mu); if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = FailedPrecondition( + it.second = strings::StrCat( "Operation '", op.node.DebugString(), "' was changed by ", mutation_type, - " after it was run by a session. Nodes can be mutated " - "only before they are executed by a session. Either don't modify " + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " "nodes after running them or create a new session."); } } @@ -722,10 +722,11 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); const Graph& graph = session->graph->graph; - status->status = session->graph->sessions[session]; - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); } const auto num_nodes = graph.num_node_ids(); @@ -2475,7 +2476,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->sessions[new_session] = Status::OK(); + graph->sessions[new_session] = ""; } return new_session; } else { @@ -2541,7 +2542,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->sessions[session] = Status::OK(); + graph->sessions[session] = ""; session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index b32f574628c4d1dc5c3bb3f1265a1b12adee28bc..fe85f8ee0ed2c58c3ba9201a9ca895c9ec48c022 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1496,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieves the type of the device at the given index. // @@ -1506,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieve the amount of memory associated with a given device. // // If index is out of bounds, an error code will be set in the status object, // and -1 will be returned. TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status*); + const TF_DeviceList* list, int index, TF_Status* status); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index be7f85a5bb06dce84579b109d506ded049042b50..bea93785717e2161fcec941485ac3c3f7f3e3ed5 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,8 +17,26 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/protobuf/config.pb.h" +using tensorflow::FunctionDef; +using tensorflow::Node; +using tensorflow::NodeBuilder; +using tensorflow::Status; + +namespace { +typedef std::unique_ptr + UniqueFuncPtr; +} + +// struct TF_Operation { tensorflow::Node node; }; +static TF_Operation* ToTF_Operation(Node* node) { + return static_cast(static_cast(node)); +} + void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { tensorflow::ConfigProto& config = options->options.config; auto* optimizer_options = @@ -37,3 +55,8340 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); } } + +void TF_InitializeTPU(TF_Session* session, TF_Status* status) { + VLOG(1) << "Initializing TPU"; + TF_Operation* config_op = + TF_GraphOperationByName(session->graph, "ConfigureDistributedTPU"); + if (config_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find node ConfigureDistributedTPU in the TF graph."); + return; + } + + TF_Output config_node{config_op, 0}; + + TF_Tensor* dummy_output; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ &config_node, /*output_values*/ &dummy_output, + /*noutputs*/ 1, + /*targets*/ nullptr, /*ntargets*/ 0, + /*run_metadata*/ nullptr, status); + if (status->status.ok()) { + TF_DeleteTensor(dummy_output); + } +} + +void TF_ShutdownTPU(TF_Session* session, TF_Status* status) { + { + tensorflow::mutex_lock c(session->graph->mu); + VLOG(1) << "Shutting down TPU, with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + } + + TF_Operation* shutdown_op = + TF_GraphOperationByName(session->graph, "ShutdownDistributedTPU"); + if (shutdown_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find node ShutdownDistributedTPU in the TF graph."); + return; + } + + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, + /*noutputs*/ 0, + /*targets*/ &shutdown_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); +} + +const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { + tensorflow::mutex_lock c(graph->mu); + const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + +// On success, returns a set of TF_Function instances from `text_proto` of +// GraphDef type. These functions must be deleted by calling TF_DeleteFunction. +// +// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto, +// before creating a TF_Function out of the possibly mutated proto. +static std::vector CreateFunctionsFromTextProto( + const char* text_proto, + std::function* mutate_proto_func, TF_Status* status) { + tensorflow::GraphDef gdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for GraphDef: ", text_proto); + return {}; + } + const auto& fdef_lib = gdef.library(); + if (fdef_lib.gradient_size() > 0) { + status->status = tensorflow::errors::Internal( + "GradientDef is not supported in reading Dataset related functions: ", + text_proto); + return {}; + } + std::vector ret; + for (const FunctionDef& fdef : fdef_lib.function()) { + // Make a copy so that we can mutate it. + FunctionDef fdef_to_load = fdef; + if (mutate_proto_func) { + (*mutate_proto_func)(&fdef_to_load); + } + VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString(); + std::vector binary_proto_buf(fdef_to_load.ByteSizeLong()); + fdef_to_load.SerializeToArray(binary_proto_buf.data(), + binary_proto_buf.size()); + TF_Function* func = TF_FunctionImportFunctionDef( + binary_proto_buf.data(), binary_proto_buf.size(), status); + if (!status->status.ok()) return {}; + ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction)); + } + return ret; +} + +// On success, returns a newly created TF_Function instance encoding a dataset +// node stack that returns a sequence of 3 floats, and sets `dataset_name` to +// the created dataset name. The returned function must be deleted by calling +// TF_DeleteFunction. +static UniqueFuncPtr CreateFakeDatasetFunction(std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "_make_dataset_d8de2712" + output_arg { + name: "TensorSliceDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000(B\000\000,B\000\0000B" + } + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/tensors/component_0:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "TensorSliceDataset" + value: "TensorSliceDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_d8de2712"; + auto functions = CreateFunctionsFromTextProto( + func_def, /*mutate_proto_func*/ nullptr, status); + DCHECK_EQ(functions.size(), 1); + return std::move(functions[0]); +} + +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateImagenetDatasetFunctions( + const char* file_path, std::string* dataset_name, TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_91295dea" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "FlatMapDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "flat_filenames/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node_def { + name: "flat_filenames" + op: "Reshape" + input: "arg0" + input: "flat_filenames/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "flat_filenames:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "FlatMapDataset" + op: "FlatMapDataset" + input: "TensorSliceDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_0cc8c35b" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + ret { + key: "FlatMapDataset" + value: "FlatMapDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_0cc8c35b" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "TFRecordDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "compression_type" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8388608 + } + } + } + } + node_def { + name: "TFRecordDataset" + op: "TFRecordDataset" + input: "arg0" + input: "compression_type:output:0" + input: "buffer_size:output:0" + } + ret { + key: "TFRecordDataset" + value: "TFRecordDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_74b6b15c" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "Reshape_1" + type: DT_FLOAT + } + output_arg { + name: "sub_1" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "ParseSingleExample/key_image/class/label" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape" + op: "Reshape" + input: "ParseSingleExample/key_image/class/label:output:0" + input: "ParseSingleExample/Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/class/text" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1" + op: "Reshape" + input: "ParseSingleExample/key_image/class/text:output:0" + input: "ParseSingleExample/Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/encoded" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2" + op: "Reshape" + input: "ParseSingleExample/key_image/encoded:output:0" + input: "ParseSingleExample/Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/format" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "jpeg" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3" + op: "Reshape" + input: "ParseSingleExample/key_image/format:output:0" + input: "ParseSingleExample/Reshape_3/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/ParseSingleExample" + op: "ParseSingleExample" + input: "arg0" + input: "ParseSingleExample/Reshape:output:0" + input: "ParseSingleExample/Reshape_1:output:0" + input: "ParseSingleExample/Reshape_2:output:0" + input: "ParseSingleExample/Reshape_3:output:0" + attr { + key: "Tdense" + value { + list { + type: DT_INT64 + type: DT_STRING + type: DT_STRING + type: DT_STRING + } + } + } + attr { + key: "dense_keys" + value { + list { + s: "image/class/label" + s: "image/class/text" + s: "image/encoded" + s: "image/format" + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + } + shape { + } + shape { + } + shape { + } + } + } + } + attr { + key: "num_sparse" + value { + i: 5 + } + } + attr { + key: "sparse_keys" + value { + list { + s: "image/object/bbox/xmax" + s: "image/object/bbox/xmin" + s: "image/object/bbox/ymax" + s: "image/object/bbox/ymin" + s: "image/object/class/label" + } + } + } + attr { + key: "sparse_types" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_INT64 + } + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:2" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/Substr/pos:output:0" + input: "decode_image/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/is_jpeg/Substr/pos:output:0" + input: "decode_image/is_jpeg/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\377\330\377" + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal" + op: "Equal" + input: "decode_image/is_jpeg/Substr:output:0" + input: "decode_image/is_jpeg/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/Switch" + op: "Switch" + input: "decode_image/is_jpeg/Equal:z:0" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/pred_id" + op: "Identity" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/check_jpeg_channels/x:output:0" + input: "decode_image/cond_jpeg/check_jpeg_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/check_jpeg_channels:z:0" + input: "decode_image/cond_jpeg/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg" + op: "DecodeJpeg" + input: "decode_image/cond_jpeg/DecodeJpeg/Switch:output_true:0" + input: "^decode_image/cond_jpeg/Assert/Assert" + attr { + key: "acceptable_fraction" + value { + f: 1.0 + } + } + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dct_method" + value { + s: "" + } + } + attr { + key: "fancy_upscaling" + value { + b: true + } + } + attr { + key: "ratio" + value { + i: 1 + } + } + attr { + key: "try_recover_truncated" + value { + b: false + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\211PN" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png" + op: "Equal" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/is_png/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/Switch" + op: "Switch" + input: "decode_image/Substr:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png:z:0" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng" + op: "DecodePng" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1:output_true:0" + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dtype" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "GIF" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/is_gif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/is_gif/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd" + op: "LogicalAnd" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1:z:0" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif" + op: "DecodeGif" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1:output_true:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr" + op: "Substr" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "BM" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp" + op: "DecodeBmp" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + attr { + key: "channels" + value { + i: 0 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp:image:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Merge:output:0" + input: "decode_image/cond_jpeg/cond_png/DecodePng:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/Merge:output:0" + input: "decode_image/cond_jpeg/DecodeJpeg:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/Cast" + op: "Cast" + input: "decode_image/cond_jpeg/Merge:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.00392156885937 + } + } + } + } + node_def { + name: "convert_image" + op: "Mul" + input: "convert_image/Cast:y:0" + input: "convert_image/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 4 + } + } + tensor_content: "\000\000\000\000\000\000\000\000\000\000\200?\000\000\200?" + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.10000000149 + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2" + op: "SampleDistortedBoundingBoxV2" + input: "distorted_bounding_box_crop/Shape:output:0" + input: "Const:output:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "area_range" + value { + list { + f: 0.0799999982119 + f: 1.0 + } + } + } + attr { + key: "aspect_ratio_range" + value { + list { + f: 0.75 + f: 1.33333337307 + } + } + } + attr { + key: "max_attempts" + value { + i: 1 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "use_image_if_no_bounding_boxes" + value { + b: true + } + } + } + node_def { + name: "distorted_bounding_box_crop/Slice" + op: "Slice" + input: "convert_image:z:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:begin:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:size:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Shape_1" + op: "Shape" + input: "distorted_bounding_box_crop/Slice:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "Shape:output:0" + input: "Shape_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "Equal:z:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + } + node_def { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "Sum" + op: "Sum" + input: "Cast:y:0" + input: "Const_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node_def { + name: "GreaterEqual/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "GreaterEqual" + op: "GreaterEqual" + input: "Sum:output:0" + input: "GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Switch" + op: "Switch" + input: "GreaterEqual:z:0" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_f" + op: "Identity" + input: "cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/pred_id" + op: "Identity" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/Shape" + op: "Shape" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Shape/Switch" + op: "Switch" + input: "convert_image:z:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@convert_image" + } + } + } + } + node_def { + name: "cond/Cast" + op: "Cast" + input: "cond/Shape:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice/stack:output:0" + input: "cond/strided_slice/stack_1:output:0" + input: "cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/strided_slice_1/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice_1/stack:output:0" + input: "cond/strided_slice_1/stack_1:output:0" + input: "cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Greater" + op: "Greater" + input: "cond/strided_slice:output:0" + input: "cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Switch" + op: "Switch" + input: "cond/Greater:z:0" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_t" + op: "Identity" + input: "cond/cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_f" + op: "Identity" + input: "cond/cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/pred_id" + op: "Identity" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/strided_slice/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice/stack:output:0" + input: "cond/cond/strided_slice/stack_1:output:0" + input: "cond/cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice_1/stack:output:0" + input: "cond/cond/strided_slice_1/stack_1:output:0" + input: "cond/cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv" + op: "RealDiv" + input: "cond/cond/strided_slice:output:0" + input: "cond/cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul/y" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul" + op: "Mul" + input: "cond/cond/truediv:z:0" + input: "cond/cond/mul/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast/x/1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast/x" + op: "Pack" + input: "cond/cond/mul:z:0" + input: "cond/cond/Cast/x/1:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast" + op: "Cast" + input: "cond/cond/Cast/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_2/stack:output:0" + input: "cond/cond/strided_slice_2/stack_1:output:0" + input: "cond/cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice_2/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_3/stack:output:0" + input: "cond/cond/strided_slice_3/stack_1:output:0" + input: "cond/cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv_1" + op: "RealDiv" + input: "cond/cond/strided_slice_2:output:0" + input: "cond/cond/strided_slice_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul_1/y" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul_1" + op: "Mul" + input: "cond/cond/truediv_1:z:0" + input: "cond/cond/mul_1/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast_1/x/0" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast_1/x" + op: "Pack" + input: "cond/cond/Cast_1/x/0:output:0" + input: "cond/cond/mul_1:z:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast_1" + op: "Cast" + input: "cond/cond/Cast_1/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Merge" + op: "Merge" + input: "cond/cond/Cast_1:y:0" + input: "cond/cond/Cast:y:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic/images" + op: "Pack" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic" + op: "ResizeBicubic" + input: "cond/ResizeBicubic/images:output:0" + input: "cond/cond/Merge:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_2/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2" + op: "StridedSlice" + input: "cond/ResizeBicubic:resized_images:0" + input: "cond/strided_slice_2/stack:output:0" + input: "cond/strided_slice_2/stack_1:output:0" + input: "cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_1" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_3/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3" + op: "StridedSlice" + input: "cond/Shape_1:output:0" + input: "cond/strided_slice_3/stack:output:0" + input: "cond/strided_slice_3/stack_1:output:0" + input: "cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_2" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_4/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4" + op: "StridedSlice" + input: "cond/Shape_2:output:0" + input: "cond/strided_slice_4/stack:output:0" + input: "cond/strided_slice_4/stack_1:output:0" + input: "cond/strided_slice_4/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/sub/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub" + op: "Sub" + input: "cond/strided_slice_3:output:0" + input: "cond/sub/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add" + op: "Add" + input: "cond/sub:z:0" + input: "cond/add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv/Cast" + op: "Cast" + input: "cond/add:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/Cast_1" + op: "Cast" + input: "cond/truediv/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv" + op: "RealDiv" + input: "cond/truediv/Cast:y:0" + input: "cond/truediv/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/sub_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub_1" + op: "Sub" + input: "cond/strided_slice_4:output:0" + input: "cond/sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add_1" + op: "Add" + input: "cond/sub_1:z:0" + input: "cond/add_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv_1/Cast" + op: "Cast" + input: "cond/add_1:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/Cast_1" + op: "Cast" + input: "cond/truediv_1/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1" + op: "RealDiv" + input: "cond/truediv_1/Cast:y:0" + input: "cond/truediv_1/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Shape_3" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Rank" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal" + op: "Equal" + input: "cond/Rank:output:0" + input: "cond/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Assert/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert" + op: "Assert" + input: "cond/Equal:z:0" + input: "cond/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/strided_slice_5/stack" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_2" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_5" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_5/stack:output:0" + input: "cond/strided_slice_5/stack_1:output:0" + input: "cond/strided_slice_5/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/stack/0" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack/1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack" + op: "Pack" + input: "cond/stack/0:output:0" + input: "cond/stack/1:output:0" + input: "cond/strided_slice_5:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/strided_slice_6/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_6/stack:output:0" + input: "cond/strided_slice_6/stack_1:output:0" + input: "cond/strided_slice_6/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual" + op: "GreaterEqual" + input: "cond/strided_slice_6:output:0" + input: "cond/GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_7/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_7/stack:output:0" + input: "cond/strided_slice_7/stack_1:output:0" + input: "cond/strided_slice_7/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual_1" + op: "GreaterEqual" + input: "cond/strided_slice_7:output:0" + input: "cond/GreaterEqual_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/LogicalAnd" + op: "LogicalAnd" + input: "cond/GreaterEqual:z:0" + input: "cond/GreaterEqual_1:z:0" + } + node_def { + name: "cond/Assert_1/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert" + op: "Assert" + input: "cond/LogicalAnd:z:0" + input: "cond/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/stack_1/2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_DOUBLE + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_DOUBLE + tensor_shape { + } + double_val: 0.0 + } + } + } + } + node_def { + name: "cond/stack_1" + op: "Pack" + input: "cond/truediv:z:0" + input: "cond/truediv_1:z:0" + input: "cond/stack_1/2:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ToInt32" + op: "Cast" + input: "cond/stack_1:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Slice" + op: "Slice" + input: "cond/strided_slice_2:output:0" + input: "cond/ToInt32:y:0" + input: "cond/stack:output:0" + input: "^cond/Assert_1/Assert" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/Reshape" + op: "Reshape" + input: "cond/Slice:output:0" + input: "cond/stack:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images" + op: "Pack" + input: "cond/ResizeBicubic_1/images/Switch:output_false:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images/Switch" + op: "Switch" + input: "distorted_bounding_box_crop/Slice:output:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@distorted_bounding_box_crop/Slice" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1/size" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\340\000\000\000\340\000\000\000" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1" + op: "ResizeBicubic" + input: "cond/ResizeBicubic_1/images:output:0" + input: "cond/ResizeBicubic_1/size:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_8/stack" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_1" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_2" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8" + op: "StridedSlice" + input: "cond/ResizeBicubic_1:resized_images:0" + input: "cond/strided_slice_8/stack:output:0" + input: "cond/strided_slice_8/stack_1:output:0" + input: "cond/strided_slice_8/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Merge" + op: "Merge" + input: "cond/strided_slice_8:output:0" + input: "cond/Reshape:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\354Q\370>\325x\351>;\337\317>" + } + } + } + } + node_def { + name: "sub" + op: "Sub" + input: "cond/Merge:output:0" + input: "Const_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\372~j>B`e>fff>" + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "sub:z:0" + input: "Const_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/control_dependency" + op: "Identity" + input: "truediv:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/min" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/max" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/RandomUniform" + op: "RandomUniform" + input: "random_flip_left_right/random_uniform/shape:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/sub" + op: "Sub" + input: "random_flip_left_right/random_uniform/max:output:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/mul" + op: "Mul" + input: "random_flip_left_right/random_uniform/RandomUniform:output:0" + input: "random_flip_left_right/random_uniform/sub:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform" + op: "Add" + input: "random_flip_left_right/random_uniform/mul:z:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Less/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node_def { + name: "random_flip_left_right/Less" + op: "Less" + input: "random_flip_left_right/random_uniform:z:0" + input: "random_flip_left_right/Less/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Switch" + op: "Switch" + input: "random_flip_left_right/Less:z:0" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_t" + op: "Identity" + input: "random_flip_left_right/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_f" + op: "Identity" + input: "random_flip_left_right/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/pred_id" + op: "Identity" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/axis" + op: "Const" + input: "^random_flip_left_right/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2" + op: "ReverseV2" + input: "random_flip_left_right/ReverseV2/Switch:output_true:0" + input: "random_flip_left_right/ReverseV2/axis:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/Switch" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Switch_1" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Merge" + op: "Merge" + input: "random_flip_left_right/Switch_1:output_false:0" + input: "random_flip_left_right/ReverseV2:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\340\000\000\000\340\000\000\000\003\000\000\000" + } + } + } + } + node_def { + name: "Reshape_1" + op: "Reshape" + input: "random_flip_left_right/Merge:output:0" + input: "Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape_2" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:0" + input: "Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast_1" + op: "Cast" + input: "Reshape_2:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_INT64 + } + } + } + node_def { + name: "sub_1/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "sub_1" + op: "Sub" + input: "Cast_1:y:0" + input: "sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "Reshape_1" + value: "Reshape_1:output:0" + } + ret { + key: "sub_1" + value: "sub_1:z:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_5fa5e1f4" + output_arg { + name: "PrefetchDataset_1" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "TensorSliceDataset/MatchingFiles" + op: "MatchingFiles" + input: "TensorSliceDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/MatchingFiles:filenames:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles" + op: "MatchingFiles" + input: "ShuffleDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "ShuffleDataset/Shape" + op: "Shape" + input: "ShuffleDataset/MatchingFiles:filenames:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice" + op: "StridedSlice" + input: "ShuffleDataset/Shape:output:0" + input: "ShuffleDataset/strided_slice/stack:output:0" + input: "ShuffleDataset/strided_slice/stack_1:output:0" + input: "ShuffleDataset/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "ShuffleDataset/Maximum/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/Maximum" + op: "Maximum" + input: "ShuffleDataset/strided_slice:output:0" + input: "ShuffleDataset/Maximum/y:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "TensorSliceDataset:handle:0" + input: "ShuffleDataset/Maximum:z:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ShuffleDataset_1/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed2_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1" + op: "ShuffleDataset" + input: "ShuffleDataset:handle:0" + input: "ShuffleDataset_1/buffer_size:output:0" + input: "ShuffleDataset_1/seed_1:output:0" + input: "ShuffleDataset_1/seed2_1:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "ShuffleDataset_1:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/cycle_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/block_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/sloppy" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: true + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/buffer_output_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/prefetch_input_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset" + op: "ParallelInterleaveDataset" + input: "RepeatDataset:handle:0" + input: "ParallelInterleaveDataset/cycle_length:output:0" + input: "ParallelInterleaveDataset/block_length:output:0" + input: "ParallelInterleaveDataset/sloppy:output:0" + input: "ParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_91295dea" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ShuffleDataset_2/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed2_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2" + op: "ShuffleDataset" + input: "ParallelInterleaveDataset:handle:0" + input: "ShuffleDataset_2/buffer_size_1:output:0" + input: "ShuffleDataset_2/seed_2:output:0" + input: "ShuffleDataset_2/seed2_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ParallelMapDataset/num_parallel_calls" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 64 + } + } + } + } + node_def { + name: "ParallelMapDataset" + op: "ParallelMapDataset" + input: "ShuffleDataset_2:handle:0" + input: "ParallelMapDataset/num_parallel_calls:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_74b6b15c" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "PrefetchDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "PrefetchDataset" + op: "PrefetchDataset" + input: "ParallelMapDataset:handle:0" + input: "PrefetchDataset/buffer_size_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "PrefetchDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + node_def { + name: "PrefetchDataset_1/buffer_size_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "PrefetchDataset_1" + op: "PrefetchDataset" + input: "FilterDataset:handle:0" + input: "PrefetchDataset_1/buffer_size_3:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 64 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: 64 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + ret { + key: "PrefetchDataset_1" + value: "PrefetchDataset_1:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_5fa5e1f4"; + std::function mutate_proto_func = + [dataset_name, file_path](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found = false; + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() != "TensorSliceDataset/MatchingFiles/pattern" && + node_def.name() != "ShuffleDataset/MatchingFiles/pattern") + continue; + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found = true; + DCHECK_EQ(node_def.attr().at("value").tensor().string_val(0), + "$(DATA_DIR)"); + VLOG(1) << "Setting the value of node_def " + "TensorSliceDataset/MatchingFiles/pattern to " + << file_path; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(file_path); + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +} + +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads an MNIST file dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateMNISTDatasetFunctions( + const char* file_path, int batch_size, std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_521bfd08" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "truediv" + type: DT_FLOAT + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "DecodeRaw:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 784 + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "Cast:y:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "truediv/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 255.0 + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "Reshape:output:0" + input: "truediv/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "truediv" + value: "truediv:z:0" + } + } + function { + signature { + name: "tf_map_func_9a08860d" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "ToInt32" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "DecodeRaw:output:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_UINT8 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ToInt32" + op: "Cast" + input: "Reshape:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + ret { + key: "ToInt32" + value: "ToInt32:y:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_2451e43a" + output_arg { + name: "FilterDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "FixedLengthRecordDataset/filenames" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-images-idx3-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/header_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/record_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 784 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/footer_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset/filenames:output:0" + input: "FixedLengthRecordDataset/header_bytes:output:0" + input: "FixedLengthRecordDataset/record_bytes:output:0" + input: "FixedLengthRecordDataset/footer_bytes:output:0" + input: "FixedLengthRecordDataset/buffer_size:output:0" + } + node_def { + name: "MapDataset" + op: "MapDataset" + input: "FixedLengthRecordDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_521bfd08" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/filenames_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-labels-idx1-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/header_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/record_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/footer_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset_1/filenames_1:output:0" + input: "FixedLengthRecordDataset_1/header_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/record_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/buffer_size_1:output:0" + } + node_def { + name: "MapDataset_1" + op: "MapDataset" + input: "FixedLengthRecordDataset_1:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_9a08860d" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + } + node_def { + name: "ZipDataset" + op: "ZipDataset" + input: "MapDataset:handle:0" + input: "MapDataset_1:handle:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "CacheDataset/filename" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "CacheDataset" + op: "CacheDataset" + input: "ZipDataset:handle:0" + input: "CacheDataset/filename:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "CacheDataset:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "ShuffleDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 50000 + } + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "RepeatDataset:handle:0" + input: "ShuffleDataset/buffer_size_2:output:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "ShuffleDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + ret { + key: "FilterDataset" + value: "FilterDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_2451e43a"; + std::function mutate_proto_func = + [dataset_name, file_path, batch_size](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found_file_path = false, found_batch_size = false; + // `node_def` may be mutated. + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() == "FixedLengthRecordDataset/filenames" || + node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_file_path = true; + // Replace $(DATA_DIR)/foo with /foo + // TODO(hongm): Use StringPiece manipulation for better efficiency. + const std::string cur_value = + node_def.attr().at("value").tensor().string_val(0); + const std::string pattern = "$(DATA_DIR)"; + DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); + const std::string new_value = + file_path + cur_value.substr(pattern.length()); + VLOG(1) << "Setting the value of node_def " << node_def.name() + << " to " << new_value; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(new_value); + } else if (node_def.name() == "BatchDataset/batch_size" || + node_def.name() == "FilterDataset/batch_size_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_batch_size = true; + // Replace $(BATCH_SIZE) with `batch_size` + DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); + VLOG(1) << "Setting the batch size attr value of node_def " + << node_def.name() << " to " << batch_size; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_int64_val(); + tensor->add_int64_val(batch_size); + } + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found_file_path); + DCHECK(found_batch_size); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +} + +// Adds the input functions to `graph`. On success, returns the created +// IteratorGetNext node. +static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph( + const std::vector& funcs, const std::string& dataset_name, + const std::vector& output_types, + const std::vector& output_shapes, + TF_Graph* graph, TF_Status* status) { + DCHECK(!dataset_name.empty()); + for (auto& func : funcs) { + TF_GraphCopyFunction(graph, func.get(), /*gradient*/ nullptr, status); + if (!status->status.ok()) { + return nullptr; + } + } + + tensorflow::mutex_lock c(graph->mu); + + tensorflow::NameAttrList func; + func.set_name(dataset_name); + // Run the iterator node on CPU. + Node* oneshot_iterator_node; + tensorflow::Status s = NodeBuilder("OneShotIterator", "OneShotIterator") + .Device("/device:CPU:0") + .Attr("container", "") + .Attr("dataset_factory", func) + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Attr("shared_name", "") + .Finalize(&graph->graph, &oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + // Run the iterator node on CPU. + Node* getnext_node; + s = NodeBuilder("IteratorGetNext", "IteratorGetNext") + .Input(oneshot_iterator_node) + .Device("/device:CPU:0") + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Finalize(&graph->graph, &getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); + return ToTF_Operation(getnext_node); +} + +TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, + TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + UniqueFuncPtr result_func = CreateFakeDatasetFunction(&dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector funcs; + funcs.push_back(std::move(result_func)); + std::vector output_shape_list; + output_shape_list.push_back(tensorflow::TensorShapeProto()); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT}, output_shape_list, graph, + status); + if (!status->status.ok()) { + return nullptr; + } + + return getnext_node; +} + +TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + const auto& funcs = + is_mnist + ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, + status) + : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector output_shape_list; + // batch_size X 224 X 224 X 3 + auto image_shape = tensorflow::TensorShapeProto(); + image_shape.add_dim()->set_size(batch_size); + if (is_mnist) { + image_shape.add_dim()->set_size(784); + } else { + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(3); + } + output_shape_list.push_back(image_shape); + + // batch_size + auto label_shape = tensorflow::TensorShapeProto(); + label_shape.add_dim()->set_size(batch_size); + output_shape_list.push_back(label_shape); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT, tensorflow::DT_INT32}, + output_shape_list, graph, status); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::mutex_lock c(graph->mu); + VLOG(1) << "The extended graph: " + << graph->graph.ToGraphDefDebug().DebugString(); + + return getnext_node; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 5a7b007e40aa199889b2d00b2bde5976c19e2966..ebcec8176b63f9a91c847ebe96fba3ff023fc599 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -25,6 +25,7 @@ limitations under the License. // Experimental C API for TensorFlow. // // The API here is subject to changes in the future. +// -------------------------------------------------------------------------- // Macro to control visibility of exported symbols in the shared library (.so, // .dylib, .dll). @@ -59,6 +60,53 @@ extern "C" { TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Initializes TPU system. Must be called exactly once before TF_SessionRun() is +// called on a TPU graph. +// +// The session graph must contain a node named ConfigureDistributedTPU. +// TODO(b/74774824): Improve the API on initializing TPU system. +TF_CAPI_EXPORT extern void TF_InitializeTPU(TF_Session* session, + TF_Status* status); + +// Shuts down TPU system. For any `session` where TF_InitializeTPU() has +// been successfully called, this call must be made exactly once before the +// session is closed. +// The session graph must contain a node named ShutdownDistributedTPU. +TF_CAPI_EXPORT extern void TF_ShutdownTPU(TF_Session* session, + TF_Status* status); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Creates a stack of data set + iterator nodes, currently hard-coded to return +// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, +// returns the IteratorGetNext node, which caller can run or feed into an node. +// +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets( + TF_Graph* graph, TF_Status* status); + +// Similar to the above API, except that the returned iterator reads the +// file based dataset from `file_path`. +// If `is_mnist` is 0, the dataset corresponds to ImageNet. +// The iterators outputs 2 tensors: +// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or +// `batch_size` X 224 X 224 X 3 otherwise. +// - An int32 tensor of shape `batch_size` +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30fcfd401d9d634962d64aaa3bf348de91f2ecae --- /dev/null +++ b/tensorflow/c/c_api_experimental_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void TestFakeIteratorStack() { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + const float base_value = 42.0; + for (int i = 0; i < 3; ++i) { + csession.SetOutputs({get_next}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); + ASSERT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); + float* output_contents = static_cast(TF_TensorData(out)); + ASSERT_EQ(base_value + i, *output_contents); + } + + // This should error out since we've exhausted the iterator. + csession.Run(s); + ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } + +TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + const string file_path = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); + VLOG(1) << "data file path is " << file_path; + const int batch_size = 64; + TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets( + graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + // The two output tensors should look like: + // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) + // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Running iter " << i; + csession.SetOutputs({{get_next, 0}, {get_next, 1}}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + { + TF_Tensor* image = csession.output_tensor(0); + ASSERT_TRUE(image != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); + // Confirm shape is 224 X 224 X 3 + ASSERT_EQ(4, TF_NumDims(image)); + ASSERT_EQ(batch_size, TF_Dim(image, 0)); + ASSERT_EQ(224, TF_Dim(image, 1)); + ASSERT_EQ(224, TF_Dim(image, 2)); + ASSERT_EQ(3, TF_Dim(image, 3)); + ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, + TF_TensorByteSize(image)); + } + + { + TF_Tensor* label = csession.output_tensor(1); + ASSERT_TRUE(label != nullptr); + ASSERT_EQ(TF_INT32, TF_TensorType(label)); + ASSERT_EQ(1, TF_NumDims(label)); + ASSERT_EQ(batch_size, TF_Dim(label, 0)); + ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); + } + } + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index e885a699274cfae04d5a17c736da1acfddcc7b3b..95652a11378d6276b5ba6540a07baa15aa77cc1c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -84,19 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // The keys of this map are all the active sessions using this graph. - // Each value is the current "runnability" status of the corresponding - // session. Under normal conditions all statuses are Status::OK(), but - // if some operation is mutated after it was run by a session (this - // is detected in RecordMutation function), that session is no longer - // safe to run. Its status will contain the error that will be returned - // to the user, should she try running this session. + // The keys of this map are all the active sessions using this graph. Each + // value records whether the graph has been mutated since the corresponding + // session has been run (this is detected in RecordMutation function). If the + // string is empty, no mutation has occurred. Otherwise the string is a + // description of the mutation suitable for returning to the user. // // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. // TF_Graph may only / must be deleted when // sessions.size() == 0 && delete_requested == true - tensorflow::gtl::FlatMap sessions + // + // TODO(b/74949947): mutations currently trigger a warning instead of a bad + // status, this should be reverted when possible. + tensorflow::gtl::FlatMap sessions GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 028f146be31790b211e546978302e81afe26b231..ca80db23ed3ccbbdc49c61db6cd03ff735470512 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -53,7 +53,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 22f77e7b874a13b3b6e0fbe981b4188c634db439..f3b28c1708129d39e451d927a89c0d10e2193b63 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -94,18 +94,22 @@ TF_Tensor* FloatTensor(float v) { // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, - TF_DataType dtype, TF_Operation** op) { + TF_DataType dtype, const std::vector& dims, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", dtype); + if (!dims.empty()) { + TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } *op = TF_FinishOperation(desc, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_NE(*op, nullptr); } TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name, - TF_DataType dtype) { + TF_DataType dtype, const std::vector& dims) { TF_Operation* op; - PlaceholderHelper(graph, s, name, dtype, &op); + PlaceholderHelper(graph, s, name, dtype, dims, &op); return op; } diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index d87c57fd5193129665ca65761872a38131ee532b..cd19cf8d624d9b914b61132f93d918b046cdbd30 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -48,7 +48,8 @@ TF_Tensor* FloatTensor(float v); TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name = "feed", - TF_DataType dtype = TF_INT32); + TF_DataType dtype = TF_INT32, + const std::vector& dims = {}); TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3046d9064a6d4b39cd8a7209d7f20e1e779c2847..a2d96357ac8a55be7fe03bf58e33ff1733967dd1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -27,6 +27,14 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:execute_node", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -54,12 +62,17 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) @@ -94,6 +107,7 @@ tf_cuda_library( "//conditions:default": [ "//tensorflow/c:c_api", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 0811bd363f2aecedc94488b6ee87fac3f4b2af14..c96a38dec3ed7bcbbd77415ec3b158390def797e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -32,6 +32,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/node_def_util.h" @@ -71,18 +74,6 @@ std::atomic_int_fast64_t func_id_generator(0); } // namespace -TFE_ContextDevicePlacementPolicy PlacementPolicy( - bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy) { - if (!soft_placement) { - return original_policy; - } - if (original_policy == TFE_DEVICE_PLACEMENT_EXPLICIT || - original_policy == TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) { - return TFE_DEVICE_PLACEMENT_SILENT; - } - return original_policy; -} - extern "C" { TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } @@ -104,19 +95,7 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { - { - tensorflow::mutex_lock l(ctx->async_map_mu); - ctx->thread_local_async[std::this_thread::get_id()] = async; - } - if (async) { - ctx->executor.EnableAsync(); - } else { - // TODO(agarwal): Currently we add a wait here to handle cases where a sync - // op has a control dependency on an async op, and the latter has not - // executed yet. This wait can be removed by storing all the control inputs - // and waiting for them when executing ops. - status->status = ctx->executor.WaitForAllPendingNodes(); - } + status->status = ctx->context.SetAsyncForThread(async); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -133,60 +112,47 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::DeviceMgr(devices)); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context(*opts, std::move(device_mgr), r); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, std::move(device_mgr), r); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.WaitForAllPendingNodes(); - { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); - } - ctx->rendezvous->Unref(); delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->device_manager->ListDeviceAttributes(&list->response); + ctx->context.device_mgr()->ListDeviceAttributes(&list->response); return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); -} +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - ctx->thread_local_policies[std::this_thread::get_id()] = policy; + ctx->context.SetThreadLocalDevicePlacementPolicy( + static_cast(policy)); } // Note: this function looks up a thread local policy. So it should be called in // the appropriate client thread. In particular, in async mode, it may not be -// safe to call this function from the async TFE_Executor threads. +// safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - auto policy_map_it = - ctx->thread_local_policies.find(std::this_thread::get_id()); - if (policy_map_it != ctx->thread_local_policies.end()) { - return policy_map_it->second; - } - return ctx->policy; + return static_cast( + ctx->context.GetDevicePlacementPolicy()); } void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.WaitForAllPendingNodes(); + status->status = ctx->context.AsyncWait(); } void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->executor.status(); + status->status = ctx->context.GetStatus(); } void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->executor.ClearError(); + ctx->context.ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { @@ -198,29 +164,32 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { DCHECK(h); - h->Unref(); + if (h->handle) { + h->handle->Unref(); + } + delete h; } TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->dtype); + return static_cast(h->handle->dtype); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dims(); } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dim_size(dim_index); } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; - status->status = h->OpDevice(&d); + status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } @@ -230,98 +199,28 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; + tensorflow::TensorHandle* h_cpu = nullptr; if (!IsCPU(d)) { - TF_SetStatus(status, TF_UNIMPLEMENTED, - tensorflow::strings::StrCat( - "TFE_TensorHandle can be resolved iff it is on CPU (this " - "handle is on ", - d->name(), - "). Consider using TFE_TensorHandleCopyToDevice to get a " - "copy of the tensor on CPU") - .c_str()); - return nullptr; - } - return tensorflow::TF_TensorFromTensor(*t, status); -} -} // extern "C" - -namespace { - -tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - tensorflow::Device* dstd, - TFE_TensorHandle** output) { - const tensorflow::Tensor* src = nullptr; - tensorflow::Device* srcd = nullptr; - // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept - // nullptr. - tensorflow::Device* src_opd = nullptr; - TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); - if (srcd == nullptr) srcd = ctx->devices[0]; - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(*src, dstd, dstd); - return tensorflow::Status::OK(); - } - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - return tensorflow::errors::InvalidArgument( - "Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), " to device ", - DeviceName(dstd), "."); - } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); - return tensorflow::Status::OK(); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; - } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; + status->status = h->handle->CopyToDevice( + h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. - // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - TF_RETURN_IF_ERROR(srcd->Sync()); - tensorflow::Notification n; - tensorflow::Status status; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [&status, &n](const tensorflow::Status& s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - if (status.ok()) { - dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); + TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); + if (h_cpu != nullptr) { + h_cpu->Unref(); } - return status; + return retval; } -} // namespace +} // extern "C" extern "C" { @@ -332,8 +231,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - tensorflow::mutex_lock l(ctx->functions_mu); - if (ctx->func_lib_def.Find(name) != nullptr) { + if (ctx->context.FindFunctionByName(name)) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -346,15 +244,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = op->ctx->device_manager->LookupDevice(device_name, &d); - if (!status->status.ok()) return; + status->status = op->ctx->context.FindDeviceByName(device_name, &d); } op->device = d; } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->devices[0] : op->device; + (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; return device->name().c_str(); } @@ -367,19 +264,8 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - if (op->device == nullptr) { - // Questionable heuristic ... - // - If a device was explicitly set on the op, always use that. - // - If not, place on the first non-host device seen. - tensorflow::Device* d = nullptr; - // TODO(agarwal): This call may block if h is not ready. Avoid this if - // possible. - status->status = h->Device(&d); - if (!status->status.ok()) return; - if (!IsCPU(d)) op->device = d; - } - h->Ref(); - op->inputs.push_back(h); + h->handle->Ref(); + op->inputs.push_back(h->handle); op->attrs.NumInputs(op->inputs.size()); } @@ -545,10 +431,39 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, namespace { +// Initializes the step stats if needed. +void MaybeInitializeStepStats(tensorflow::StepStats* step_stats, + tensorflow::EagerContext* ctx) { + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices()->size()) { + int device_idx = step_stats->dev_stats_size(); + auto* dev_stats = step_stats->add_dev_stats(); + dev_stats->set_device(ctx->devices()->at(device_idx)->name()); + } +} + +int StepStatsDeviceIndex(tensorflow::StepStats* step_stats, + tensorflow::EagerContext* ctx, + tensorflow::Device* device) { + // Find the current device's index. + if (device == nullptr) { + device = ctx->HostCPU(); + } + for (int i = 0; i < ctx->devices()->size(); ++i) { + if (ctx->devices()->at(i) == device || + ctx->devices()->at(i)->name() == device->name()) { + return i; + } + } + // TODO(apassos) do not fall back to host CPU if device is unknown. + return 0; +} + tensorflow::Status ValidateInputTypeAndPlacement( - TFE_Context* ctx, tensorflow::Device* host_device, - tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel) { + tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) { + tensorflow::Device* host_device = ctx->HostCPU(); const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -557,14 +472,14 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; - TFE_TensorHandle* handle = op->inputs[i]; + tensorflow::TensorHandle* handle = op->inputs[i]; tensorflow::Device* handle_device = nullptr; TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = handle_device == nullptr ? host_device : handle_device; if (expected_device != actual_device) { - switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { - case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: + switch (ctx->GetDevicePlacementPolicy()) { + case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. if (handle->dtype == tensorflow::DT_INT32) { @@ -573,7 +488,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( break; } TF_FALLTHROUGH_INTENDED; - case TFE_DEVICE_PLACEMENT_EXPLICIT: + case tensorflow::DEVICE_PLACEMENT_EXPLICIT: return tensorflow::errors::InvalidArgument( "Tensors on conflicting devices:" " cannot compute ", @@ -581,11 +496,13 @@ tensorflow::Status ValidateInputTypeAndPlacement( expected_device->name(), " but is actually on ", actual_device->name(), " (operation running on ", op_device->name(), ")", - " Tensors can be copied explicitly using .gpu() or .cpu()," - " or transparently copied by using tfe.enable_eager_execution(" - "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" + " Tensors can be copied explicitly using .gpu() or .cpu() " + "methods," + " or transparently copied by using tf.enable_eager_execution(" + "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors " + "between devices" " may slow down your model"); - case TFE_DEVICE_PLACEMENT_WARN: + case tensorflow::DEVICE_PLACEMENT_WARN: LOG(WARNING) << "before computing " << op->name << " input #" << i << " was expected to be on " << expected_device->name() << " but is actually on " << actual_device->name() @@ -593,16 +510,27 @@ tensorflow::Status ValidateInputTypeAndPlacement( << "). This triggers a copy which can be a performance " "bottleneck."; break; - case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing. break; } // We are only here if the policy is warn or silent copies, so we should // trigger a copy. - TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - handle, ctx, expected_device->name().c_str(), s); - tensorflow::Status status = s->status; - TF_DeleteStatus(s); + auto pre_time = tensorflow::Env::Default()->NowMicros(); + tensorflow::TensorHandle* copied_tensor = nullptr; + tensorflow::Status status = tensorflow::EagerCopyToDevice( + handle, ctx, expected_device->name().c_str(), &copied_tensor); + if (run_metadata != nullptr) { + auto* step_stats = run_metadata->mutable_step_stats(); + MaybeInitializeStepStats(step_stats, ctx); + // Record the sending on the source device for now. + int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device); + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + auto* node_stats = dev_stats->add_node_stats(); + node_stats->set_node_name("_Send"); + node_stats->set_all_start_micros(pre_time); + node_stats->set_op_end_rel_micros( + tensorflow::Env::Default()->NowMicros() - pre_time); + } if (!status.ok()) { if (copied_tensor != nullptr) copied_tensor->Unref(); return tensorflow::errors::Internal( @@ -629,7 +557,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, TFE_Context* ctx, TF_Status* status) { tensorflow::DeviceSet ds; - for (tensorflow::Device* d : ctx->devices) { + for (tensorflow::Device* d : *ctx->context.devices()) { ds.AddDevice(d); } tensorflow::DeviceTypeVector final_devices; @@ -643,7 +571,7 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, "Could not find valid device for node ", ndef.DebugString()); return nullptr; } - for (tensorflow::Device* d : ctx->devices) { + for (tensorflow::Device* d : *ctx->context.devices()) { if (d->device_type() == final_devices[0].type_string()) { return d; } @@ -653,186 +581,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, return nullptr; } -tensorflow::Status Execute( - TFE_Context* ctx, tensorflow::Device* device, - const tensorflow::gtl::InlinedVector& op_inputs, - tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, - TFE_TensorHandle** retvals, int num_retvals) { - if (!ctx->soft_placement && device == nullptr) { - // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU - device = ctx->devices[0]; - } - - if (device == nullptr) { - // TODO(apassos) debug how the assignment below might return a different - // device from the one requested above. - device = kernel->device(); - } - - std::vector outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; - output_memory_types = &kernel->kernel()->output_memory_types(); - std::vector inputs(op_inputs.size()); - for (int i = 0; i < op_inputs.size(); ++i) { - const tensorflow::Tensor* input_tensor = nullptr; - TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor)); - inputs[i] = *input_tensor; - } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - // TODO(agarwal): change Run to take vector of handles ? - TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats)); - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices.size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->devices.size(); ++i) { - if (ctx->devices[i] == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; - } - } - DCHECK_EQ(num_retvals, outputs.size()); - tensorflow::Device* op_device = IsCPU(device) ? nullptr : device; - for (int i = 0; i < num_retvals; ++i) { - tensorflow::Device* d = op_device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; - } - if (retvals[i] == nullptr) { - retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device); - } else { - retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); - } - } - return tensorflow::Status::OK(); -} - -// TODO(agarwal): move TFE_Executor and TFE_Node related code to a separate -// file. -class ExecuteNode : public TFE_Node { - public: - ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel, - tensorflow::NodeExecStats* maybe_stats, - const tensorflow::DataTypeVector& output_dtypes, - TFE_TensorHandle** retvals, int num_retvals) - : TFE_Node(op->ctx->executor.NextId()), - ctx_(op->ctx), - op_device_(op->device), - inputs_(op->inputs), - kernel_(kernel), - maybe_stats_(maybe_stats), - retvals_(num_retvals) { - for (auto handle : inputs_) { - handle->Ref(); - } - TFE_Context* ctx = op->ctx; - for (int i = 0; i < num_retvals; ++i) { - TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx); - h->Ref(); - retvals[i] = h; - retvals_[i] = h; - } - } - - ~ExecuteNode() override { - for (auto handle : inputs_) { - handle->Unref(); - } - for (auto handle : retvals_) { - handle->Unref(); - } - } - - tensorflow::Status Run() override { - const tensorflow::Status status = - Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), - retvals_.begin(), retvals_.size()); - if (status.ok()) { - return status; - } else { - return tensorflow::Status( - status.code(), - tensorflow::strings::StrCat("Got error, \"", status.error_message(), - "\" while executing kernel ", - kernel_->kernel()->def().DebugString())); - } - } - - private: - TFE_Context* ctx_; - tensorflow::Device* op_device_; - tensorflow::gtl::InlinedVector inputs_; - tensorflow::KernelAndDevice* kernel_; - std::unique_ptr maybe_stats_; - tensorflow::gtl::InlinedVector retvals_; -}; - -class CopyToDeviceNode : public TFE_Node { - public: - CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, - TFE_Context* ctx) - : TFE_Node(ctx->executor.NextId()), - src_(src), - dstd_(dstd), - ctx_(ctx), - dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) { - src_->Ref(); - dst_->Ref(); - } - - ~CopyToDeviceNode() override { - src_->Unref(); - dst_->Unref(); - } - - tensorflow::Status Run() override { - TFE_TensorHandle* temp = nullptr; - TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); - const tensorflow::Tensor* tensor = nullptr; - tensorflow::Device* device = nullptr; - tensorflow::Device* op_device = nullptr; - tensorflow::Status status = - temp->TensorAndDevice(&tensor, &device, &op_device); - // `temp` is a ready handle. So the following call should return OK. - TF_DCHECK_OK(status) << status.error_message(); - DCHECK(tensor); - dst_->SetTensorAndDevice(*tensor, device, op_device); - temp->Unref(); - return tensorflow::Status::OK(); - } - - TFE_TensorHandle* dst() { return dst_; } - - private: - TFE_TensorHandle* src_; - tensorflow::Device* dstd_; - TFE_Context* ctx_; - TFE_TensorHandle* dst_; -}; #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a @@ -861,8 +609,7 @@ const tensorflow::FunctionDef* OpToFunction( TFE_Context* ctx = op->ctx; const tensorflow::OpRegistrationData* op_data; { - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + status->status = ctx->context.FindFunctionOpData(op->name, &op_data); if (!status->status.ok()) { return nullptr; } @@ -958,10 +705,9 @@ const tensorflow::FunctionDef* OpToFunction( } VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(fdef); + status->status = ctx->context.AddFunctionDef(fdef); if (!status->status.ok()) return nullptr; - const auto ret = ctx->func_lib_def.Find(signature->name()); + const auto ret = ctx->context.FindFunctionDef(signature->name()); DCHECK(ret != nullptr); return ret; } @@ -980,8 +726,7 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { const tensorflow::FunctionDef* fdef; { - tensorflow::tf_shared_lock l(op->ctx->functions_mu); - fdef = op->ctx->func_lib_def.Find(op->name); + fdef = op->ctx->context.FindFunctionDef(op->name); } std::vector const_input_types; std::vector arg_input_types; @@ -1008,7 +753,7 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { // Since input param reordering may have occurred between `op` and `launch_op` // via `op_input_to_func_input`, adjust the actual inputs accordingly. launch_op->inputs = op->inputs; - for (TFE_TensorHandle* h : launch_op->inputs) { + for (tensorflow::TensorHandle* h : launch_op->inputs) { h->Ref(); } if (!op_input_to_func_input.empty()) { @@ -1058,7 +803,7 @@ extern "C" { void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { TFE_Context* ctx = op->ctx; - status->status = ctx->executor.status(); + status->status = ctx->context.GetStatus(); if (!status->status.ok()) { return; } @@ -1079,10 +824,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, tensorflow::Device* input_op_device = nullptr; status->status = op->inputs[i]->OpDevice(&input_op_device); if (!status->status.ok()) return; + VLOG(2) << "for op " << op->name << " input " << i << " " + << tensorflow::DataTypeString(op->inputs[i]->dtype) << " " + << (input_op_device == nullptr ? "cpu" : input_op_device->name()) + << " " << (op->device == nullptr ? "cpu" : op->device->name()); if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && - input_op_device != op->device) { + (input_op_device != op->device || input_op_device == nullptr)) { tensorflow::Device* d = - input_op_device == nullptr ? ctx->devices[0] : input_op_device; + input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; VLOG(1) << "Changing device of operation " << op->name << " to " << d->name() << " because input #" << i << " is a resource in this device."; @@ -1090,40 +839,32 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } } tensorflow::Device* device = op->device; - if (!ctx->soft_placement && device == nullptr) { - // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU - device = ctx->devices[0]; - } tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); - tensorflow::KernelAndDevice* kernel; - { - tensorflow::tf_shared_lock l(ctx->cache_mu); - kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); - } + tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - if (ctx->soft_placement && device == nullptr) { + if (device == nullptr) { device = SelectDevice(ndef, ctx, status); if (!status->status.ok()) { return; } } CHECK(device != nullptr); - if (ctx->log_device_placement) { + if (ctx->context.LogDevicePlacement()) { LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); + kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. // See WARNING comment in Execute (before kernel->Run) - would be nice to // rework to avoid this subtlety. - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = - tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); + tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); + status->status = tensorflow::KernelAndDevice::Init( + ndef, ctx->context.func_lib(device), kernel); if (!status->status.ok()) { delete kernel; return; @@ -1131,7 +872,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // Update output_dtypes inside `kernel`. const tensorflow::OpDef* op_def = nullptr; const tensorflow::FunctionDef* function_def = - ctx->func_lib_def.Find(ndef.op()); + ctx->context.FuncLibDef()->Find(ndef.op()); if (function_def != nullptr) { op_def = &(function_def->signature()); } @@ -1147,8 +888,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, if (!status->status.ok()) { return; } - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); + ctx->context.AddKernelToCache(cache_key, kernel); } const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); const int output_dtypes_size = output_dtypes.size(); @@ -1166,11 +906,13 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // device from the one requested above. device = kernel->device(); } - status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device, - op, kernel->kernel()); + status->status = ValidateInputTypeAndPlacement( + &ctx->context, device, op, kernel->kernel(), + ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto() + : nullptr); if (!status->status.ok()) return; std::unique_ptr maybe_stats; - if (ctx->should_store_metadata.load()) { + if (ctx->context.ShouldStoreMetadata()) { maybe_stats.reset(new tensorflow::NodeExecStats); maybe_stats->set_node_name(op->name); maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); @@ -1178,21 +920,34 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - if (ctx->Async()) { + if (ctx->context.Async()) { // Note that for async mode, execution order will make sure that all // input handles are ready before executing them. // TODO(agarwal): Consider executing "cheap" kernels inline for performance. - TFE_Node* node = new ExecuteNode(op, kernel, maybe_stats.release(), - output_dtypes, retvals, *num_retvals); - ctx->executor.Add(node); + tensorflow::gtl::InlinedVector handle_retvals( + *num_retvals); + tensorflow::uint64 id = op->ctx->context.NextId(); + for (int i = 0; i < *num_retvals; ++i) { + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context); + retvals[i] = new TFE_TensorHandle(h); + handle_retvals[i] = h; + } + tensorflow::EagerNode* node = new tensorflow::ExecuteNode( + id, &op->ctx->context, op->device, op->inputs, kernel, + maybe_stats.release(), output_dtypes, handle_retvals); + ctx->context.ExecutorAdd(node); } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. + std::vector handle_retvals(*num_retvals, + nullptr); + status->status = tensorflow::EagerExecute( + &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = nullptr; + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } - status->status = Execute(op->ctx, op->device, op->inputs, kernel, - maybe_stats.get(), retvals, *num_retvals); } } @@ -1200,26 +955,13 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - status->status = ctx->executor.status(); - if (!status->status.ok()) { - return nullptr; - } - tensorflow::Device* dstd = ctx->devices[0]; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->device_manager->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - if (ctx->Async()) { - // Note that `h` may not be currently ready. However execution order will - // make sure that `h` is ready before the copy is actually done. - CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); - ctx->executor.Add(node); - return node->dst(); - } else { - TFE_TensorHandle* output = nullptr; - status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); - return output; + tensorflow::TensorHandle* handle; + status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + device_name, &handle); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); } + return nullptr; } void TFE_ContextAddFunctionDef(TFE_Context* ctx, @@ -1231,24 +973,20 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function_def); + status->status = ctx->context.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); + status->status = ctx->context.AddFunctionDef(function->fdef); } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); + ctx->context.SetShouldStoreMetadata(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); + ctx->context.SetShouldStoreMetadata(false); } } // extern "C" @@ -1262,7 +1000,7 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -1277,9 +1015,9 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { TFE_ContextAsyncWait(ctx, status); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(ctx->metadata_mu); - status->status = MessageToBuffer(ctx->run_metadata, buf); - ctx->run_metadata.Clear(); + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); + ctx->context.RunMetadataProto()->Clear(); } namespace { @@ -1353,207 +1091,9 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } } // namespace tensorflow -TFE_Node::TFE_Node(tensorflow::uint64 id) : id(id) {} - -TFE_Executor::~TFE_Executor() { - tensorflow::mutex_lock l(node_queue_mutex_); - thread_done_ = true; - nodes_pending_.notify_all(); -} - -tensorflow::uint64 TFE_Executor::NextId() { - tensorflow::mutex_lock l(next_id_mutex_); - return next_id_++; -} - -void TFE_Executor::EnableAsync() { - tensorflow::mutex_lock l(node_queue_mutex_); - if (thread_ == nullptr) { - thread_.reset(tensorflow::Env::Default()->StartThread( - tensorflow::ThreadOptions(), "eager_async_executor", - std::bind(&TFE_Executor::Run, this))); - } -} - -void TFE_Executor::Add(TFE_Node* node) { - tensorflow::mutex_lock l(node_queue_mutex_); - DCHECK(thread_) << "EnableAsync should have been called before Add"; - if (!status_.ok()) { - delete node; - return; - } - int qlen = node_queue_.size(); - if (qlen > 0) { - if (node_queue_.back()->id >= node->id) { - status_ = tensorflow::errors::InvalidArgument( - "Inserting TFE_Node with non-increasing ids:", node_queue_.back()->id, - " vs ", node->id); - delete node; - return; - } - node_queue_.push(node); - } else { - node_queue_.push(node); - nodes_pending_.notify_all(); - } -} - -tensorflow::Status TFE_Executor::WaitFor(tensorflow::uint64 node_id) { - return WaitImpl(false, node_id); -} - -tensorflow::Status TFE_Executor::WaitForAllPendingNodes() { - return WaitImpl(true, 0); -} - -tensorflow::Status TFE_Executor::WaitImpl(bool wait_all, - tensorflow::uint64 node_id) { - tensorflow::condition_variable cond; - tensorflow::mutex_lock l(node_queue_mutex_); - // Don't wait if an error is already set. - if (!status_.ok()) return status_; - if (node_queue_.empty()) return tensorflow::Status::OK(); - if (wait_all) { - node_id = node_queue_.back()->id; - } else if (node_id < node_queue_.front()->id) { - // Note that we are relying on the ops being dispatched sequentially from - // the queue. - return tensorflow::Status::OK(); - } - node_done_notifications_.insert(std::make_pair(node_id, &cond)); - cond.wait(l); - // Note that we could be woken up if an error occurs, even though the node has - // not actually executed. - return status_; -} - -void TFE_Executor::ClearError() { - tensorflow::mutex_lock l(node_queue_mutex_); - if (status_.ok()) return; - // If an error was set, node_done_notifications_ and node_queue_ should have - // been cleared, and no new entries should have been added since. - DCHECK(node_done_notifications_.empty()); - DCHECK(node_queue_.empty()); - status_ = tensorflow::Status::OK(); - nodes_pending_.notify_all(); -} - -tensorflow::Status TFE_Executor::status() { - tensorflow::mutex_lock l(node_queue_mutex_); - return status_; -} - -void TFE_Executor::Run() { - while (true) { - std::unique_ptr curr_node; - { - tensorflow::mutex_lock l(node_queue_mutex_); - while (node_queue_.empty() || !status_.ok()) { - if (thread_done_) return; - nodes_pending_.wait(l); - } - curr_node.reset(node_queue_.front()); - } - tensorflow::Status status = curr_node->Run(); - const bool ok = status.ok(); - tensorflow::mutex_lock l(node_queue_mutex_); - node_queue_.pop(); - if (!ok) { - status_ = status; - // TODO(agarwal): mark all affected handles as corrupted before clearing - // this queue. - // We remove any pending ops so that we don't try to execute them if - // ClearError is called. - for (int i = 0; i < node_queue_.size(); ++i) { - delete node_queue_.front(); - node_queue_.pop(); - } - } - if (!node_done_notifications_.empty()) { - tensorflow::uint64 node_id = curr_node->id; - // Note that we notify all waiting threads in case an error has occurred. - // These calling threads are responsible for checking status_ before - // proceeding. - const auto range = ok ? node_done_notifications_.equal_range(node_id) - : make_pair(node_done_notifications_.begin(), - node_done_notifications_.end()); - for (auto it = range.first; it != range.second; ++it) { - it->second->notify_all(); - } - node_done_notifications_.erase(range.first, range.second); - } - } -} - -bool TFE_Context::Async() const { - tensorflow::mutex_lock l(async_map_mu); - return tensorflow::gtl::FindWithDefault( - thread_local_async, std::this_thread::get_id(), async_default); -} - -bool TFE_TensorHandle::IsReady() { - if (node_id == 0) return true; - tensorflow::mutex_lock l(ctx_mutex_); - return ctx_ == nullptr; -} - -tensorflow::Status TFE_TensorHandle::WaitReady() { - if (node_id == 0) return tensorflow::Status::OK(); - TFE_Executor* executor = nullptr; - { - tensorflow::mutex_lock l(ctx_mutex_); - if (ctx_ == nullptr) return tensorflow::Status::OK(); - executor = &ctx_->executor; - } - return executor->WaitFor(node_id); -} - -tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *t = &tensor_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = device_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = op_device_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::TensorAndDevice( - const tensorflow::Tensor** tensor, tensorflow::Device** device, - tensorflow::Device** op_device) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *tensor = &tensor_; - *device = device_; - *op_device = op_device_; - return tensorflow::Status::OK(); -} - -void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device) { - tensorflow::mutex_lock l(ctx_mutex_); - DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " - << "on non-ready handles."; - ctx_ = nullptr; - tensor_ = tensor; - device_ = device; - op_device_ = op_device; -} TFE_Op::~TFE_Op() { - for (TFE_TensorHandle* h : inputs) { + for (tensorflow::TensorHandle* h : inputs) { h->Unref(); } } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a5029bf2115c7dac54d03b8bc6397bc63349c068..3926c22ce1f9e194b1452c796c83944d10cfdc64 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -61,17 +61,15 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( // Controls how to act when we try to run an operation on a given device but // some input tensors are not on that device. typedef enum TFE_ContextDevicePlacementPolicy { - // Running operations with input tensors on the wrong device will fail. When - // soft placement is enabled acts like TFE_DEVICE_PLACEMENT_SILENT. + // Running operations with input tensors on the wrong device will fail. TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, - // Silently copy the tensor, which has a performance cost since the - // operation will be blocked till the copy completes. + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default placement + // policy. TFE_DEVICE_PLACEMENT_SILENT = 2, - // Default placement policy which silently copies int32 tensors but not other - // dtypes. When soft placement is enabled acts like - // TFE_DEVICE_PLACEMENT_SILENT. + // Placement policy which silently copies int32 tensors but not other dtypes. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; @@ -162,7 +160,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); -// This function will block till the operation that produces `h` has completed. +// This function will block till the operation that produces `h` has +// completed. The memory returned might alias the internal memory used by +// TensorFlow. Hence, callers should not mutate this memory (for example by +// modifying the memory region pointed to by TF_TensorData() on the returned +// TF_Tensor). TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 8dba12f47b580c33041cc134c6f07a1fafff7453..05dc64f521735f944559392f470a37590e93f17c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -30,9 +30,14 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -40,261 +45,40 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" -// A unit of execution for the TFE_Executor class below. Example subclasses -// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one -// device to another. -class TFE_Node { - public: - explicit TFE_Node(tensorflow::uint64 id); - - virtual ~TFE_Node() {} - - // Runs the computation corresponding to this node and blocks till the - // execution is done. - virtual tensorflow::Status Run() = 0; - - // An id unique to the TFE_Context under which this node is created. Allocated - // monotonically. - const tensorflow::uint64 id; -}; - -// A class for handling async execution (see TFE_ContextSetAsync). -// Note that this class is thread-safe. -// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the -// device of the input handle. Fix that. -// TODO(agarwal): On error, mark all affected handles as corrupted. -// TODO(agarwal): Implement support for control dependencies. -// TODO(agarwal): Support out-of-order execution and dispatching multiple -// TFE_Node in parallel. -// TODO(agarwal): Implement optimizations over TFE_Node traces. -class TFE_Executor { - public: - ~TFE_Executor(); - - // This is called whenever async mode is enabled. Note that it may be called - // multiple times as different calling threads may switch async mode on or off - // independently. - void EnableAsync(); - - // Helper function to create monotonically increasing ids unique to this - // object. - tensorflow::uint64 NextId(); - - // Schedules `node` for execution. - // Note that Add must be called in monotonically increasing order of node->id. - void Add(TFE_Node* node); - - // Causes the caller to block till node with id `node_id` has finished - // execution. - tensorflow::Status WaitFor(tensorflow::uint64 node_id); - - // Blocks till all currently pending ops are done. - tensorflow::Status WaitForAllPendingNodes(); - - // Clears all currently set errors which re-enables async execution. - void ClearError(); - - // Returns Status based on any errors that occurred during async execution. - tensorflow::Status status(); - - private: - // Starts execution of pending TFE_Nodes. This function loops till - // thread_done_ is set to true. If any errors are encontered, these are set - // inside `status_`. The loop blocks anytime there are no pending nodes, or if - // `status_` is not ok. - void Run(); - - tensorflow::Status WaitImpl(bool wait_all, tensorflow::uint64 node_id); - - tensorflow::mutex node_queue_mutex_; - - // Used to signal that some TFE_Nodes are pending execution. - tensorflow::condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_); - - // Queue of pending TFE_Nodes. - std::queue node_queue_ GUARDED_BY(node_queue_mutex_); - - // `status_` is set based on any errors raised during execution of a TFE_Node. - // It remains set until ClearError is called. - tensorflow::Status status_ GUARDED_BY(node_queue_mutex_); - - // Map from id of a TFE_Node to condition_variables (not owned by the map). - // These condition_variables are notified and removed when that TFE_Node is - // done executing, or if an error is found in execution of any TFE_Node. - std::multimap - node_done_notifications_ GUARDED_BY(node_queue_mutex_); - - // Thread object that calls the `Run` method. Currently we use only one thread - // for executing the TFE_Nodes one-by-one. - std::unique_ptr thread_ GUARDED_BY(node_queue_mutex_); - - // Indicates that `thread_` should stop as soon as it is done executing the - // current TFE_Node. - bool thread_done_ GUARDED_BY(node_queue_mutex_) = false; - - tensorflow::mutex next_id_mutex_; - tensorflow::uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1; -}; struct TFE_ContextOptions { TF_SessionOptions session_options; // true if async execution is enabled. bool async = false; - TFE_ContextDevicePlacementPolicy policy{ - TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; }; -TFE_ContextDevicePlacementPolicy PlacementPolicy( - bool soft_placement, TFE_ContextDevicePlacementPolicy original_policy); - struct TFE_Context { - explicit TFE_Context(const TFE_ContextOptions& opts, + explicit TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, + bool async, std::unique_ptr device_mgr, tensorflow::Rendezvous* rendezvous) - : soft_placement( - opts.session_options.options.config.allow_soft_placement()), - policy(PlacementPolicy(soft_placement, opts.policy)), - device_manager(std::move(device_mgr)), - devices(device_manager->ListDevices()), - rendezvous(rendezvous), - pflr(new tensorflow::ProcessFunctionLibraryRuntime( - device_manager.get(), opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})), - log_device_placement( - opts.session_options.options.config.log_device_placement()), - async_default(opts.async) { - if (async_default) executor.EnableAsync(); - } - - const bool soft_placement; - const TFE_ContextDevicePlacementPolicy policy; - - // Note: we cannot use C++11 thread_local here as there is no concept of a - // thread-local-object-local variable in C++11. - tensorflow::mutex policy_map_mu; - std::unordered_map - thread_local_policies GUARDED_BY(policy_map_mu); - - std::unique_ptr device_manager; - // Devices owned by device_manager - const std::vector devices; - tensorflow::Rendezvous* const rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - const std::unique_ptr pflr; + : context(opts, + static_cast( + default_policy), + async, std::move(device_mgr), rendezvous) {} - tensorflow::mutex cache_mu; - std::unordered_map - kernel_cache GUARDED_BY(cache_mu); - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { - return pflr->GetFLR(d->name()); - } - - // Whether we should compute RunMetadata. - std::atomic should_store_metadata{false}; - tensorflow::mutex metadata_mu; - tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); - const bool log_device_placement; - // TFE_Executor for async execution. - TFE_Executor executor; - - // True if running in asynchronous mode. - bool Async() const; - - // True if the default value for execution mode is async. Note that this value - // can be overridden per thread based on `thread_local_async` overrides. - const bool async_default; - mutable tensorflow::mutex async_map_mu; - std::unordered_map thread_local_async - GUARDED_BY(async_map_mu); + tensorflow::EagerContext context; }; -struct TFE_TensorHandle : public tensorflow::core::RefCounted { - public: +struct TFE_TensorHandle { TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, tensorflow::Device* op_device) - : dtype(t.dtype()), - node_id(0), - tensor_(t), - device_(d), - op_device_(op_device), - ctx_(nullptr) {} + : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - TFE_Context* ctx) - : dtype(dtype), - node_id(node_id), - tensor_(dtype), - device_(nullptr), - op_device_(nullptr), - ctx_(ctx) { - DCHECK_GT(node_id, 0); - } - - ~TFE_TensorHandle() override {} - - tensorflow::Status Tensor(const tensorflow::Tensor** t); - - tensorflow::Status Device(tensorflow::Device** d); - - tensorflow::Status OpDevice(tensorflow::Device** d); - - tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor, - tensorflow::Device** device, - tensorflow::Device** op_device); - - // Note that this can be called at most once, and only on non-ready handles, - // and makes them ready. - void SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device); - - // dtype for the handle. It must be the same as t.dtype() once the handle is - // ready. - const tensorflow::DataType dtype; - - private: - // If the contents of the Tensor pointed to by this handle is yet to be - // computed by a TFE_Node, this function will block till that compuatation is - // done and the handle is "ready". - tensorflow::Status WaitReady(); - - bool IsReady(); - - // Id for the TFE_Node that will compute the value pointed to by this handle. - // If the value is 0, the handle is already ready, but not vice-versa. - const tensorflow::uint64 node_id; - - tensorflow::Tensor tensor_; - - // TODO(ashankar): device_ == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('device_' should always - // be a valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* device_; - - // Device in which the op producing this tensor was executed. Equals to - // device_ for constant tensors. - tensorflow::Device* op_device_; + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - tensorflow::mutex ctx_mutex_; + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} - // `ctx` is only guaranteed to be set if the handle is not "ready". This is - // typically true when the handle was produced during async execution. - // `ctx` object is not owned and should outlive this handle. - TFE_Context* ctx_ GUARDED_BY(ctx_mutex_); + tensorflow::TensorHandle* handle; }; struct TFE_Op { @@ -311,7 +95,7 @@ struct TFE_Op { const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - tensorflow::gtl::InlinedVector inputs; + tensorflow::gtl::InlinedVector inputs; tensorflow::Device* device; bool use_xla = false; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 2268aba90d60b7b2f10e99f64fd7aa3ae719badb..701175e4943d1d23532fe595319f67711316ed4d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -590,7 +590,13 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_TensorHandle* m1 = TestMatrixTensorHandle(); TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* matmul2 = MatMulOp(ctx, m1, m1); + TFE_OpSetDevice(matmul2, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1] = {nullptr}; int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); @@ -688,19 +694,19 @@ TEST(CAPI, Execute_Min_CPU) { TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); float output[2] = {0}; EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 9b46cf8245901934c9c4d41a2b7c10c1c5bf7cbd..abe2793ce894ad07c252575c5d55d98342916eac 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -95,22 +96,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, - TF_AttrType* out, unsigned char* is_list) { - auto* t = gtl::FindOrNull(m, attr_name); - if (t == nullptr) { - return errors::InvalidArgument("Attribute '", attr_name, - "' does not exist for this operation"); - } - *out = static_cast(*t & ~kIsList); - if (*t & kIsList) { - *is_list = 1; - } else { - *is_list = 0; - } - return Status::OK(); -} - #define DEFINE_SET_ATTR(value_type, value_field) \ template <> \ AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \ @@ -168,6 +153,22 @@ const NodeDef& AttrBuilder::BuildNodeDef() { return *node_def_; } +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list) { + auto* t = gtl::FindOrNull(m, attr_name); + if (t == nullptr) { + return errors::InvalidArgument("Attribute '", attr_name, + "' does not exist for this operation"); + } + *out = static_cast(*t & ~kIsList); + if (*t & kIsList) { + *is_list = 1; + } else { + *is_list = 0; + } + return Status::OK(); +} + namespace { inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, const tensorflow::Fprint128& b) { @@ -245,104 +246,4 @@ void AttrBuilder::MayBeInitializeNodeDef() { } } -// static -Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = CreateOpKernel(device->device_type().c_str(), device, - device->GetAllocator(AllocatorAttributes()), - nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); - out->device_ = device; - out->kernel_.reset(k); - out->flib_ = nullptr; - return s; -} - -// static -Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = flib->CreateKernel(ndef, &k); - out->device_ = flib->device(); - out->kernel_.reset(k); - out->flib_ = flib; - return s; -} - -Status KernelAndDevice::Run(std::vector* input_tensors, - std::vector* output_tensors, - NodeExecStats* stats) { - gtl::InlinedVector inputs; - for (Tensor& t : *input_tensors) { - inputs.push_back(TensorValue(&t)); - } - - std::vector out_attrs(kernel_->num_outputs()); - for (size_t i = 0; i < out_attrs.size(); ++i) { - out_attrs[i].set_on_host(kernel_->output_memory_types()[i] == - tensorflow::HOST_MEMORY); - } - - OpKernelContext::Params params; - params.device = device_; - params.frame_iter = FrameAndIter(0, 0); - params.inputs = &inputs; - params.op_kernel = kernel_.get(); - params.resource_manager = device_->resource_manager(); - params.output_attr_array = gtl::vector_as_array(&out_attrs); - params.function_library = flib_; - params.slice_reader_cache = &slice_reader_cache_; - params.rendezvous = rendez_; - if (stats != nullptr) { - params.track_allocations = true; - } - // TODO(apassos): use a thread pool. - std::function)> runner = - [](std::function f) { f(); }; - params.runner = &runner; - - OpKernelContext context(¶ms); - - if (kernel_->def().op() == "_Recv") { - // TODO(apassos) do not special-case _Recv. Currently the GPU device fails - // if trying to run _Recv->Compute(), specifically checking for _Recv. To go - // around this we call _Recv->ComputeAsync, to mimic graph mode behavior. - AsyncOpKernel* async = kernel_->AsAsync(); - Notification done; - device_->ComputeAsync(async, &context, [&done]() { done.Notify(); }); - done.WaitForNotification(); - } else { - device_->Compute(kernel_.get(), &context); - } - if (!context.status().ok()) return context.status(); - - output_tensors->clear(); - for (int i = 0; i < context.num_outputs(); ++i) { - output_tensors->push_back(Tensor(*context.mutable_output(i))); - } - if (stats != nullptr) { - for (const auto& allocator_pair : context.wrapped_allocators()) { - AllocatorMemoryUsed* memory = stats->add_memory(); - memory->set_allocator_name(allocator_pair.first->Name()); - auto sizes = allocator_pair.second->GetSizes(); - memory->set_total_bytes(std::get<0>(sizes)); - memory->set_peak_bytes(std::get<1>(sizes)); - memory->set_live_bytes(std::get<2>(sizes)); - - AllocatorStats allocator_stats; - allocator_pair.first->GetStats(&allocator_stats); - memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); - allocator_pair.second->GetRecordsAndUnRef(); - } - auto* ms = stats->mutable_memory_stats(); - ms->set_temp_memory_size(context.temp_memory_allocated()); - for (const auto& alloc_id : context.persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - - ms->set_persistent_memory_size(context.persistent_memory_allocated()); - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index ad16f65495f8a8193b685c2b13a099232d03a505..929b1b8296faf61c11c68af06ffc4ca3770ae929 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -45,6 +46,10 @@ Status OpDefForOp(const char* op_name, const OpDef** op_def); // Returns the AttrTypeMap for the TensorFlow operation named op_name. Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); +// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); + // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); @@ -149,53 +154,6 @@ template <> AttrBuilder& AttrBuilder::Set(StringPiece attr_name, tensorflow::DataType&& value); -// KernelAndDevice encapsulates an instantiated kernel and the device it is on. -// -// Also see: -// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h -// and -// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h -class KernelAndDevice { - public: - // Populates 'out' with a kernel appropriate for 'ndef'. - // - // The provided FunctionLibraryRuntime MUST outlive all calls to - // Run() on the returned KernelAndDevice. - // - // TODO(ashankar): Figure out thread-safety concerns around - // FunctionLibraryRuntime (in particular, how the underlying - // FunctionLibraryDefinition might be mutated by another thread as new - // functions are registered with it). Conservatively, thread-safe usage of - // the FunctionLibraryRuntime is pushed on to the caller (see locking in - // c_api.cc). - static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out); - // TODO(ashankar): Remove this - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} - - // TODO(ashankar): Handle list-valued inputs. - Status Run(std::vector* inputs, std::vector* outputs, - NodeExecStats* stats); - - const OpKernel* kernel() const { return kernel_.get(); } - - Device* device() const { return device_; } - - DataTypeVector* mutable_output_dtypes() { return &output_dtypes_; } - const DataTypeVector& output_dtypes() { return output_dtypes_; } - - private: - std::unique_ptr kernel_; - Device* device_; - FunctionLibraryRuntime* flib_; - checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; - Rendezvous* rendez_; - DataTypeVector output_dtypes_; -}; } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 4f75d278878d7c8ff6a5e48e5b4e633aa13aedc5..27ebeb0508844ee1ee89e0733b66f6ed129b7757 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -33,27 +33,6 @@ limitations under the License. namespace tensorflow { namespace { -class TestEnv { - public: - TestEnv() : flib_def_(OpRegistry::Global(), {}) { - Device* device = - DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); - device_mgr_.reset(new DeviceMgr({device})); - flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), - device, TF_GRAPH_DEF_VERSION, - &flib_def_, nullptr, {}, nullptr); - } - - FunctionLibraryRuntime* function_library_runtime() const { - return flib_runtime_.get(); - } - - private: - FunctionLibraryDefinition flib_def_; - std::unique_ptr device_mgr_; - std::unique_ptr flib_runtime_; -}; - TEST(AttrTypeMap, Lookup) { const AttrTypeMap* m = nullptr; Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m); @@ -79,113 +58,5 @@ TEST(AttrTypeMap, Lookup) { EXPECT_NE(is_list, 0); } -TEST(KernelAndDevice, Run) { - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - Status s = - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); - ASSERT_TRUE(s.ok()) << s; - std::vector outputs; - s = kernel.Run(&inputs, &outputs, nullptr); - ASSERT_TRUE(s.ok()) << s; - ASSERT_EQ(1, outputs.size()); - const Tensor& out = outputs[0]; - EXPECT_EQ(7, out.matrix()(0, 0)); - EXPECT_EQ(10, out.matrix()(0, 1)); - EXPECT_EQ(15, out.matrix()(1, 0)); - EXPECT_EQ(22, out.matrix()(1, 1)); -} - -void BM_CreateGraph(int iters) { - for (int i = 0; i < iters; ++i) { - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - TF_CHECK_OK(root.status()); - } -} -BENCHMARK(BM_CreateGraph); - -void BM_RunGraph(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - SessionOptions opts; - opts.config.set_inter_op_parallelism_threads(1); - opts.config.set_intra_op_parallelism_threads(1); - ClientSession sess(root, opts); - std::vector outputs; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - outputs.clear(); - TF_CHECK_OK(sess.Run({M}, &outputs)); - } -} -BENCHMARK(BM_RunGraph); - -void BM_CreateAndDestroySession(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - ClientSession sess(root); - } -} -BENCHMARK(BM_CreateAndDestroySession); - -void BM_KernelAndDeviceInit(int iters) { - tensorflow::testing::StopTiming(); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(2) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice k(nullptr); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &k)); - } -} -BENCHMARK(BM_KernelAndDeviceInit); - -void BM_KernelAndDeviceRun(int iters) { - tensorflow::testing::StopTiming(); - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - std::vector outputs; - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); - } -} -BENCHMARK(BM_KernelAndDeviceRun); } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index bdb0815d6b68444ec1c89b835d563db20ce4d8a1..97c323b87228039ba10f4ed5e434aa83621b1220 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -152,6 +152,8 @@ class GradientTape { gtl::ArraySlice output_gradients, std::vector* result); + bool IsPersistent() const { return persistent_; } + private: TensorTape tensor_tape_; OpTape op_tape_; @@ -599,23 +601,28 @@ Status GradientTape::ComputeGradient( } CHECK(state.op_tape.empty()); result->reserve(source_tensor_ids.size()); + gtl::FlatSet used_gradient_ids(source_tensor_ids.size()); for (auto is : source_tensor_ids) { auto grad_it = gradients.find(is); if (grad_it == gradients.end()) { result->push_back(nullptr); } else { - if (grad_it->second.size() == 1) { - result->push_back(grad_it->second[0]); - } else { - result->push_back(vspace.AggregateGradients(grad_it->second)); + if (grad_it->second.size() > 1) { + Gradient* grad = vspace.AggregateGradients(grad_it->second); + grad_it->second.clear(); + grad_it->second.push_back(grad); } - gradients.erase(grad_it); + result->push_back(grad_it->second[0]); + used_gradient_ids.insert(is); } } - VLOG(1) << "Final gradients size: " << gradients.size(); + VLOG(1) << "Final gradients size: " + << gradients.size() - used_gradient_ids.size(); for (auto grad_pair : gradients) { - for (const auto& g : grad_pair.second) { - vspace.DeleteGradient(g); + if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } } } return Status::OK(); diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index cd604538f1fa142c6fe6a76624c048baddaa52fb..93155998b86d59ec78c7ff25f146b8e3c8eac380 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/python_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" namespace tensorflow { @@ -109,4 +110,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } +std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + string result; + handle_data.SerializeToString(&result); + return result; +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 13b680b3a24afa2d285ea18207578aff4350f6d5..2d4c8cd9ed7bc926f448dab1f6b50ed74179ea14 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_PYTHON_API_H_ #define TENSORFLOW_C_PYTHON_API_H_ +#include + #include "tensorflow/c/c_api.h" // These functions can be removed without notice. They exist to facilitate some @@ -51,6 +53,11 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); // the graph after the session has been made aware of them. void ExtendSession(TF_Session* session, TF_Status* status); +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource tensor, or otherwise returns the empty string. +// TODO(b/74620627): remove when _USE_C_SHAPES is removed +std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/testdata/tf_record b/tensorflow/c/testdata/tf_record new file mode 100644 index 0000000000000000000000000000000000000000..6e16076bfb79ad8151952e96567565e8820b0f5b Binary files /dev/null and b/tensorflow/c/testdata/tf_record differ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 9060c19e9d2cf965c2b9be07be07c42017da45a8..079e063d3e3fbdaf833e9031f5f9438853c14099 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -620,18 +620,6 @@ tf_cc_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "queue_runner", srcs = ["training/queue_runner.cc"], diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 1e0f2d241bb350897a840dda90d6d0c009b1daad..5d9dfd95a5538ae0f3d2d111a1f989552c3363b8 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -61,12 +62,12 @@ op { )"; void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(s.contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { - EXPECT_FALSE(s.contains(expected)) + EXPECT_FALSE(str_util::StrContains(s, expected)) << "'" << s << "' contains '" << expected << "'"; } diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 71642492627422e09c19b7bcb4dc522846cf08b1..c143b978338815ebc7134eb0a07867c5d8b13dca 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -218,7 +219,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); - if (s.Consume(kColocationGroupPrefix)) { + if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { current_constraints.insert(s.ToString()); } } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index d29ad3ebcbe29087d5572b51c7713e0c98d0d840..06a3be18e08f611d3ecf9804908d791d15fdab13 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -94,18 +94,3 @@ filegroup( "testdata/half_plus_two/**", ]), ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2..72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -133,9 +134,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied " - "tags: { missing-tag }")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); } @@ -149,9 +150,9 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags: ")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) << st.error_message(); } @@ -169,7 +170,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD index f5fbc75edcba9d5ae9ef7432de224df766bcab9e..6f04ebdc55cda329527c95f62efc37c8dfbb4ae5 100644 --- a/tensorflow/cc/saved_model/python/BUILD +++ b/tensorflow/cc/saved_model/python/BUILD @@ -7,18 +7,6 @@ package( default_visibility = ["//visibility:public"], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") tf_py_clif_cc( diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index f413a5cc52e9eb4bc393b8186f5b591681fa2e5e..6f1c87354076565af22f7ba0610a5c6bb999d25c 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -41,18 +41,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index 3675d72ee354533a7d84b5e8783cde452d8d60c9..5dbc4f5f6aa389978e55ca2656c17ff97202203d 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -166,7 +167,8 @@ namespace { bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, int32* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag) && + tensorflow::str_util::ConsumePrefix(&arg, "=")) { char extra; return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); } @@ -176,7 +178,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool* dst) { - if (arg.Consume(flag)) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *dst = true; return true; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index ffa2d088295375bbbcd2cdd9365982907f2bf480..fa03b1f3c2dfc334d4a3871e6a1bf5503fa8d5f8 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -250,17 +250,3 @@ exports_files([ "benchmark_main.template", # used by tf_library(...,gen_benchmark=True) "test.cc", # used by tf_library(...,gen_test=True) ]) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 972b7d51ecb3798e61757ac55e973075a23b433a..2642536c4f67eba8eedf315f24d800e7913d62a0 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 28aab6eb614ca7123d9e00f7f5cc3661b62e23f7..b053dad1b57c258b7cb0d6831923e6a0f30f5e7e 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -182,17 +182,3 @@ tf_cc_test( "//third_party/eigen3", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 9dff1be09fede6f65f82c2f36d94be07e781949f..3a877c5337ff76193a7f27fb9681e5a9ca500961 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -132,7 +132,7 @@ def tf_library(name, graph, config, header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" - ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index e2f01179d4e2e4f6ef72b2761d06e130ffa3a94f..8ea014c2eede2cb7a9cede9dd4ade8b970bd519c 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -55,7 +55,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (StringPiece(fname).ends_with(".pbtxt")) { + if (str_util::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c4a2d4ab0321bbf9db91f5e4387084c27e576b87..24aa203c00b3a011ae11007e308f8bbb6998204e 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -76,6 +76,7 @@ cc_library( ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -118,14 +119,33 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_tensor", + srcs = ["xla_tensor.cc"], + hdrs = ["xla_tensor.h"], + deps = [ + ":common", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "xla_device", srcs = [ + "xla_compile_on_demand_op.cc", "xla_device.cc", "xla_device_context.cc", "xla_device_ops.cc", ], hdrs = [ + "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", @@ -136,6 +156,7 @@ cc_library( ":common", ":jit_compilation_passes", ":xla_launch_util", + ":xla_tensor", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", @@ -182,6 +203,7 @@ cc_library( deps = [ ":common", ":xla_compilation_cache", + ":xla_tensor", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -328,6 +350,7 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -338,26 +361,13 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 2d175c40f9dfaef4e5024b77a6ecb8d6022e7a56..b04b333141a616e7c4db2751c14ec6eb0b7725b5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -53,6 +53,8 @@ namespace tensorflow { const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; namespace { @@ -143,7 +145,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kHostComputeOp = "_XlaHostCompute"; +static const char* const kHostComputeOp = "XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -252,7 +254,8 @@ class Encapsulator { // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out); @@ -328,12 +331,14 @@ class Encapsulator { Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to - // all the downstream nodes of call_node_outputs. - void ConnectSequencerToOutputs(Graph* graph_out); + // the call node. + void ConnectSequencerToCallNode(Graph* graph_out); Status AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph); + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library); Status ReplaceFunctionDef(FunctionLibraryDefinition* library); @@ -401,7 +406,9 @@ class Encapsulator { // Builds a _RecvAtHost node producing all the inputs of an // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. - Status AddRecvAtHostNode(const string& subgraph_name, + Status AddRecvAtHostNode(const string& group_attribute, + const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); @@ -410,8 +417,10 @@ class Encapsulator { // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. Status AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, + const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are @@ -425,6 +434,10 @@ class Encapsulator { // NodeDef for the function call node. NodeDef call_node_def_; + // Name that is used for the call node. This may not be + // call_node_def_.name() if the client supplies a rewrite lambda. + string function_def_name_; + // Placeholder node simulating the host compute key in the output graph. // Not owned. Node* host_compute_key_placeholder_ = nullptr; @@ -567,7 +580,7 @@ class Encapsulator { const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out); + std::unique_ptr* graph_out); // Makes a copy of graph containing only nodes that are ancestors of at least // one node in send_from_host_nodes and store it in pruned_graph. On exit @@ -812,6 +825,7 @@ Status Encapsulator::Subgraph::AddHostComputes( builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -863,25 +877,21 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, NodeDef seq_def; NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), "NoOp"); + builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); + builder.Device(device_); Status s = builder.Finalize(&seq_def); if (!s.ok()) return s; sequencer_ = graph_out->AddNode(seq_def, &s); if (!s.ok()) return s; - sequencer_->set_assigned_device_name(device_); } return Status::OK(); } -void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { +void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { - std::unordered_set output_dependencies; - for (Node* node : call_node_outputs_->out_nodes()) { - output_dependencies.insert(node); - } - for (Node* node : output_dependencies) { - graph_out->AddControlEdge(sequencer_, node); - } + VLOG(2) << "ConnectSequencerToCallNode"; + graph_out->AddControlEdge(sequencer_, call_node_inputs_); } } @@ -927,6 +937,8 @@ Status Encapsulator::Subgraph::BuildFunctionDef( name = call_node_def_.op(); } + function_def_name_ = name; + FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -945,8 +957,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( } Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph) { + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library) { OutsideCompilationSubgraph& oc_subgraph = outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); @@ -968,21 +982,22 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shape_inference_graph", ""); host_compute->AddAttr("shapes", shapes); } else { - string serialized_graph; - if (!inference_graph->SerializeToString(&serialized_graph)) { - return errors::Internal( - "Failed to serialize graph for outside compilation subgraph ", - oc_subgraph.host_compute_name); - } - host_compute->AddAttr("shape_inference_graph", serialized_graph); + string inference_graph_name = + strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); + FunctionDef fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); + host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = call_node_def_.name(); + const string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -1105,7 +1120,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( } Status Encapsulator::Subgraph::AddRecvAtHostNode( - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { if (host_compute_key_placeholder_ == nullptr) { TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); @@ -1128,17 +1144,19 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); - // TODO(misard) For now we only support TPU device 0. + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. builder.Attr("device_ordinal", 0); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); if (!s.ok()) return s; - oc_subgraph->recv_at_host->set_assigned_device_name(device_); graph_out->AddEdge(host_compute_key_placeholder_, 0, oc_subgraph->recv_at_host, 0); @@ -1153,7 +1171,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( Status Encapsulator::Subgraph::AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { if (host_compute_key_placeholder_ == nullptr) { TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); @@ -1182,8 +1201,11 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( builder.Attr("Tinputs", dtypes); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); - // TODO(misard) For now we only support TPU device 0. + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. builder.Attr("device_ordinal", 0); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(inputs); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&send_def); @@ -1191,7 +1213,6 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); if (!s.ok()) return s; - oc_subgraph->send_from_host->set_assigned_device_name(device_); graph_out->AddEdge(host_compute_key_placeholder_, 0, oc_subgraph->send_from_host, inputs.size()); @@ -1205,7 +1226,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out) { for (auto& outside_compilation_subgraph_entry : @@ -1215,14 +1237,16 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( outside_compilation_subgraph_entry.second; if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - TF_RETURN_IF_ERROR( - AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name, + outside_compilation_attribute, + oc_name, &oc_subgraph, graph_out)); } if (!oc_subgraph.outputs_by_src.empty() || !oc_subgraph.control_outputs.empty()) { - TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, - oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddSendFromHostNode( + node_images, group_attribute, subgraph_name, + outside_compilation_attribute, oc_name, &oc_subgraph, graph_out)); } } return Status::OK(); @@ -1439,8 +1463,6 @@ Status Encapsulator::CopyNodesToOutputGraph( "Parallel checking is not supported when outside_compilation " "clusters are present."); } - image->ClearAttr(group_attribute_); - image->ClearAttr(outside_compilation_attribute_); } (*node_images)[node] = image; } @@ -1466,7 +1488,8 @@ Status Encapsulator::AddOutsideCompilationHostIONodes( const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( - subgraph_name, node_images, graph_out)); + group_attribute_, subgraph_name, outside_compilation_attribute_, + node_images, graph_out)); } return Status::OK(); } @@ -1675,7 +1698,7 @@ Status Encapsulator::AddEdgesToOutputGraph( for (auto& subgraph_entry : subgraphs_) { Subgraph& subgraph = subgraph_entry.second; - subgraph.ConnectSequencerToOutputs(graph_out); + subgraph.ConnectSequencerToCallNode(graph_out); } return Status::OK(); @@ -1754,7 +1777,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out) { + std::unique_ptr* graph_out) { // Maps from nodes in graph_in to nodes in graph_out. // // When an edge has fully defined shape the source node in graph_in is @@ -1771,8 +1794,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( std::unordered_map dummy_node_images; std::unordered_map copied_node_images; - std::unique_ptr graph_out(new Graph(graph_in.op_registry())); - graph_out->set_versions(graph_in.versions()); + graph_out->reset(new Graph(graph_in.op_registry())); + (*graph_out)->set_versions(graph_in.versions()); // The final input to the send node is the dynamic key, which we don't include // in the static shapes. static_shape_out->resize(send_node->num_inputs() - 1); @@ -1794,7 +1817,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (w.leave) { TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( n, send_node, dummy_node_images, library, &copied_node_images, - graph_out.get())); + graph_out->get())); } else { if (visited[n->id()]) continue; visited[n->id()] = true; @@ -1818,7 +1841,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( context->ShapeHandleToProto(shape, &proto); if (dummy_node_images.find(src_node) == dummy_node_images.end()) { dummy_node_images[src_node] = AddDummyShapedNode( - src_node->output_type(src_port), proto, graph_out.get()); + src_node->output_type(src_port), proto, graph_out->get()); } // The final input to the send node is the dynamic key, which we // don't include in the static shapes. @@ -1827,8 +1850,12 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( (*static_shape_out)[in_edge->dst_input()] = proto; } } else { + has_parent_with_unknown_shape = true; if (!visited[src_node->id()]) { - has_parent_with_unknown_shape = true; + if (VLOG_IS_ON(2)) { + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + } stack.push_back({src_node, false}); } } @@ -1839,7 +1866,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // The shapes of all the inputs to send_node are statically known. We // won't have to do any inference at compile time so return now: the // shapes were stored in static_shape_out above. - graphdef_out->reset(); + graph_out->reset(); return Status::OK(); } else { // Any shape that is being processed is either the original send node @@ -1862,9 +1889,6 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( } } - graphdef_out->reset(new GraphDef()); - graph_out->ToGraphDef(graphdef_out->get()); - return Status::OK(); } @@ -1981,14 +2005,20 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", + *pruned_graph, library); + } + for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; // Find all the recv_at_host nodes in this subgraph. std::vector outside_compilation_names; subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); std::unordered_set recv_at_host_names; - for (const auto& name : outside_compilation_names) { - Node* recv_node = subgraph.GetRecvAtHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(oc_name); if (recv_node != nullptr) { recv_at_host_names.insert(recv_node->name()); } @@ -1997,26 +2027,30 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( // without knowing the shape of the recv_at_host nodes, and store the // result, along with enough information to complete the job at compile time // once the recv_at_host shapes are known. - for (const auto& name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(oc_name); std::vector static_shape; - std::unique_ptr graphdef; + std::unique_ptr graph; if (send_node != nullptr) { TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( *pruned_graph, shape_refiner, recv_at_host_names, - node_images[send_node], library, &static_shape, &graphdef)); - if (graphdef == nullptr) { + node_images[send_node], library, &static_shape, &graph)); + if (graph == nullptr) { VLOG(2) << "Send node " << send_node->name() << " shapes"; for (int i = 0; i < static_shape.size(); ++i) { VLOG(2) << static_shape[i].DebugString(); } } else { - VLOG(2) << "Send node " << send_node->name() << " graph\n" - << graphdef->DebugString(); + if (VLOG_IS_ON(2)) { + GraphDef graphdef; + graph->ToGraphDef(&graphdef); + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef.DebugString(); + } } } - TF_RETURN_IF_ERROR( - subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo( + subgraph_name, oc_name, static_shape, graph.get(), library)); } if (!outside_compilation_names.empty()) { TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index d7bea56a7244665c571c23d49c6769a163b86e9e..8599a7038af9663e5af6f3231429cb7f6ea5f69b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,22 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace { +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; + +Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, + const string& name_suffix, + FunctionDefLibrary* library) { + GraphDef graphdef; + TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); + std::unique_ptr graph = + std::unique_ptr(new Graph(OpRegistry::Global())); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get())); + FunctionDef* fdef = library->add_function(); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *graph, + strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + fdef)); + return Status::OK(); +} + template bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const ::tensorflow::protobuf::Map& b, @@ -112,23 +136,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, a.attr(), b.attr(), [](const string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, [](const string& key, const AttrValue& av, const AttrValue& bv) { - if (key == "shape_inference_graph") { - // Default serialization of GraphDef is unstable because maps don't - // serialize deterministically. Rather than go through the hoops to - // turn on deterministic serialization of this attr just for this - // test, add logic here to compare determinstically. - GraphDef ga; - if (!ga.ParseFromString(av.s())) { - return false; - } - GraphDef gb; - if (!gb.ParseFromString(bv.s())) { - return false; - } - return EqualGraphDef(ga, gb, nullptr); - } else { - return av.DebugString() == bv.DebugString(); - } + return av.DebugString() == bv.DebugString(); }, strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); @@ -248,7 +256,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, // These dummy Op registrations are here because the real Op registrations live // in contrib and there can't be a dependence from this test to contrib. -REGISTER_OP("_XlaHostCompute") +REGISTER_OP("XlaHostCompute") .Input("inputs: Tinputs") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") @@ -321,8 +329,13 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); -Node* NoOp(const GraphDefBuilder::Options& opts) { - return ops::SourceOp("NoOp", opts); +Node* Sequencer(const GraphDefBuilder::Options& opts, + const string& call_node_name) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", + opts.op_registry()); + return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name) + .FinalizeBuilder(&node_builder); } Node* Input(const GraphDefBuilder::Options& opts) { @@ -370,24 +383,36 @@ Node* KeyPlaceholder(const string& call_node, .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(ops::NodeOut key_input, const string& key, +Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, const gtl::ArraySlice& dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_recv"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); return opts.WithAttr("Toutputs", dtypes) .WithAttr("key", key) .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } -Node* SendFromHost(ops::NodeOut key_input, const string& key, +Node* SendFromHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_send"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); node_builder.Input(std::move(key_input)); @@ -398,6 +423,8 @@ Node* SendFromHost(ops::NodeOut key_input, const string& key, return opts.WithAttr("Tinputs", dtypes) .WithAttr("key", key) .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } @@ -745,7 +772,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -790,7 +817,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -840,22 +867,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -870,13 +895,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -888,28 +915,29 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost(ops::NodeOut(key_constant, 0), - "host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = + b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -959,45 +987,43 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected_1; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape1.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape1.opts().WithName("E")); - SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {e}, shape1.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape1_graph; - TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); - EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } - string shape_string_expected_2; { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - shape2.opts().WithName("E")); - Node* recv2 = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", - {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O2_recv")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); - SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", - {h}, shape2.opts().WithName("outside_compilation_F1_O2_send")); - GraphDef shape2_graph; - TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); - EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + shape2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* h = Binary(ops::NodeOut(recv2, 0), e, + shape2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } *library_expected.add_function() = FunctionDefHelper::Create( @@ -1014,22 +1040,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0", "F:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", shape_string_expected_2}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected_1}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1041,40 +1071,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), - "host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* recv2 = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", - {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O2_recv")); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* g = Binary(e, ops::NodeOut(recv2, 1), - b2.opts().WithName("G").WithControlInputs({recv2, e})); - Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); - Node* send2 = SendFromHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O2", {h}, - b2.opts().WithName("outside_compilation_F1_O2_send")); + b2.opts() + .WithName("G") + .WithControlInputs({recv2, e}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); - Node* s = NoOp(b2.opts() - .WithName("F1_sequencer") - .WithControlInputs({recv1, send1, recv2, send2})); + Node* s = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + Binary(g, call, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1123,22 +1158,20 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } TensorShapeProto shape_proto_expected; @@ -1156,13 +1189,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -1176,14 +1211,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"G:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1196,43 +1232,46 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant1, 0), "host_compute_channel_F1_O1", - {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), - "host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost( - ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", - {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", + {DT_FLOAT}, b2.opts()); Node* h = Binary(ops::NodeOut(call1, 1), recv2, - b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = SendFromHost( - ops::NodeOut(key_constant2, 0), "host_compute_channel_F2_O1", {h}, - b2.opts().WithName("outside_compilation_F2_O1_send")); + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts()); + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); Node* call2 = b2.opts() - .WithControlInputs({s1, e, call1}) + .WithControlInputs({s2, e, call1}) .FinalizeBuilder(&node_builder2); - Node* s2 = NoOp( - b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); - Binary(call2, ops::NodeOut(call2, 1), - b2.opts().WithName("J").WithControlInput(s2)); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1280,14 +1319,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1298,18 +1338,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* send1 = SendFromHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1358,14 +1402,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1380,19 +1425,23 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {}, b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); - Node* send1 = SendFromHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1434,13 +1483,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1453,16 +1503,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1509,13 +1563,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1528,22 +1583,23 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), - "host_compute_channel_F1_O1", {}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1594,7 +1650,10 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1640,21 +1699,21 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1")); - Node* recv = RecvAtHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); - SendFromHost(ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", - {e}, shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape.opts()); + Node* e = BinaryUnknownShape(known, recv, + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -1668,13 +1727,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1687,29 +1748,29 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* b = Input(b2.opts().WithName("B")); Node* c = Unary(a, b2.opts().WithName("C")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); - Node* key_constant = - KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost( - ops::NodeOut(key_constant, 0), "host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape( - c, ops::NodeOut(recv, 0), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost(ops::NodeOut(key_constant, 0), - "host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); - - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index 6fa21fa6204dcc9446081d07e2a59ccace216713..8f5e11dfa47956f1fdaa4d1ff115affa375c5c73 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -229,7 +230,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) { NodeDef* node_def = fdef->mutable_node_def(n_index); for (int i = 0; i < node_def->input_size(); ++i) { - if (StringPiece(node_def->input(i)).starts_with("^")) { + if (str_util::StartsWith(node_def->input(i), "^")) { // Control input const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 15507b3851751c681044a744c07c247410fb3e2d..676f71a75aede2a7720ae0c8a579d64cc184509a 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -27,17 +27,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 616a7f8f1541d3debff97a90bd390c76c665d196..00a6f4075f9a18efc3895b033eb6d08e36088a53 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -41,17 +41,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index cd7f8dd779120637c96d6af041b0afcc734e5eff..2d6511a45b9b37df8405d34dd2aec5ba31254c16 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -114,10 +114,12 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); + const XlaDevice::Metadata* metadata; + Status s = XlaDevice::GetMetadata(ctx, &metadata); + bool allocate_xla_tensors = s.ok(); + // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { platform_id_ = metadata->platform()->id(); } @@ -128,8 +130,23 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::LocalClient* client = static_cast(cache->client()); - // Builds an XLA allocator for the device. - XlaAllocator xla_allocator(client->platform(), ctx); + XlaAllocator local_xla_allocator(client->backend().platform(), + ctx->device()->GetAllocator({})); + xla::DeviceMemoryAllocator* xla_allocator; + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which local_xla_allocator above uses) as on an XlaDevice, this is a + // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a + // real allocator to allocate real buffers. + if (allocate_xla_tensors) { + xla_allocator = client->backend().memory_allocator(); + } else { + xla_allocator = &local_xla_allocator; + } XlaCompiler::Options options; options.client = client; @@ -137,26 +154,30 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.device_allocator = &xla_allocator; + options.device_allocator = xla_allocator; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + std::map constant_args; + for (int i = 0; i < num_constant_args_; ++i) { + constant_args.insert({i, ctx->input(i)}); + } + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context(num_resource_args_, client, - &xla_allocator); + XlaComputationLaunchContext launch_context( + num_resource_args_, client, xla_allocator, allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(&xla_allocator); + run_options.set_allocator(xla_allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -166,8 +187,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - launch_context.PopulateOutputs(ctx, kernel, - run_result.ConsumeValueOrDie()->release()); + launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 4491dd6ac8f2b84f341162eb469cc8194f817c9a..5d211f4d733d8d807426e62dd116092799184f35 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -52,16 +52,14 @@ cc_library( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", +cc_library( + name = "xla_device_flags", + srcs = ["xla_device_flags.cc"], + hdrs = ["xla_device_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], ) diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 51384ac2fe6fa70c8a723097093a0a29e7ad2c6b..7277a1d1f8ad5fa045645ead839ab9efa01e89c7 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -41,6 +41,7 @@ static void AllocateFlags() { flags->tf_xla_clustering_debug = false; flags->tf_xla_cpu_global_jit = false; flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + flags->tf_xla_fusion_only = false; flag_list = new std::vector( {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, "Control compilation of operators into XLA computations on CPU and " @@ -59,7 +60,10 @@ static void AllocateFlags() { "Enables global JIT compilation for CPU via SessionOptions."), Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, "Places an artificial limit on the number of ops marked as " - "eligible for clustering.")}); + "eligible for clustering."), + Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 170b89c987f30f985f981d7835b4af455922594e..2affda6ab4e0fbad32a246744fa5b38aeb629c1b 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -51,6 +51,10 @@ typedef struct { int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this // many ops will be marked as eligible for // clustering. + bool tf_xla_fusion_only; // This flag is effective only when global_jit_level + // is set to ON* and overrides its behavior. If + // true, enable fusion of element-wise operations + // only using XLA. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bb2fce2dbad5bffce2e33b665b7222090d0855a --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's xla_device module. + +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static XlaDeviceFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new XlaDeviceFlags; + flags->tf_xla_compile_on_demand = false; + flag_list = new std::vector({ + Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..27b22121ac1e089bd5d5a494e1e3fb60b05bc76d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ + +// Legacy flags for the XLA bridge's xla_device module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// The values of flags associated with the XLA bridge's +// xla_device module. +typedef struct { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +} XlaDeviceFlags; + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 57fb8d242208318a608b1f356bef7a8d39dbdc83..f651768a67278628e40445291d7fb271bb1ae611 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -180,6 +180,158 @@ struct NodeCompare { }; using OrderedNodeSet = std::set; +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +// +// TODO(hpucha): Consider a black list instead of a white list as +// implemented below. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/batchtospace_op.cc + "BatchToSpace", "BatchToSpaceND", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/categorical_op.cc + "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/, + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/cross_op.cc + "Cross", + // tf2xla/kernels/depthtospace_op.cc + "DepthToSpace", + // tf2xla/kernels/diag_op.cc + "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart", + // tf2xla/kernels/dynamic_stitch_op.cc + "DynamicStitch", "ParallelDynamicStitch", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fake_quantize_ops.cc + "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/, + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/gather_op.cc + "Gather", "GatherV2", "GatherNd", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", "StopGradient", + "Snapshot", + // tf2xla/kernels/image_ops.cc + "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/, + "AdjustSaturation", "AdjustHue", + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/l2loss_op.cc + "L2Loss" /*(Reduce)*/, + // tf2xla/kernels/lrn_ops.cc (ReduceWindow) + "LRN", "LRNGrad", + // tf2xla/kernels/matrix_band_part_op.cc + "MatrixBandPart", + // tf2xla/kernels/matrix_set_diag_op.cc + "MatrixSetDiag", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/no_op.cc + "NoOp", "ControlTrigger", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/pooling_ops.cc + "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool", + "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/ + "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad", + "AvgPool3DGrad", + // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce) + "QuantizeAndDequantizeV2", + // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend + // currently) + "RandomUniform", "RandomUniformInt", "RandomStandardNormal", + "TruncatedNormal", + // tf2xla/kernels/reduction_ops.cc (Reduce) + "Sum", "Prod", "Min", "Max", "Mean", "All", "Any", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/scan_ops.cc (ReduceWindow) + "Cumsum", "Cumprod", + // tf2xla/kernels/scatter_nd_op.cc (Reduce) + "ScatterNd", + // tf2xla/kernels/segment_reduction_ops.cc (Reduce) + "UnsortedSegmentSum", + // tf2xla/kernels/select_op.cc + "Select", + // tf2xla/kernels/sequence_ops.cc + "Range", "LinSpace", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/softmax_op.cc (Reduce) + "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + // tf2xla/kernels/spacetobatch_op.cc + "SpaceToBatchND", "SpaceToBatch", + // tf2xla/kernels/spacetodepth_op.cc + "SpaceToDepth", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/stack_ops.cc + "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2", + // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on + // GPU + // backend currently) + "StatelessRandomUniform", + "StatelessRandomNormal" + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", + "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/training_ops.cc + "ResourceApplyGradientDescent", "ResourceApplyMomentum", + "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp", + "ResourceApplyFtrl", "ResourceApplyFtrlV2", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, @@ -338,10 +490,13 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + bool fusion_only = flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -364,6 +519,11 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // Check for fusable ops only if requested. + if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + return false; + } + // Otherwise use the value of global_jit_level. // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a8858cccef623185709ab5dc2187a313dd130f7..2e362e0a63f16e4837e63f194920c3f585dd8a46 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -137,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, UnsupportedTypes) { +TEST(XlaCompilationTest, Complex128Unsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -157,6 +158,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, HalfSupported) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Tensor t(DT_HALF, TensorShape()); + t.scalar()() = static_cast(0.0f); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_HALF) + .WithAttr("value", t)); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_FALSE(clusters.empty()); +} + TEST(XlaCompilationTest, ConcatWithConstArg) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -519,11 +541,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.ToString()) - .contains("Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(str_util::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index e5787ca4c8cff436e4404b8488970248b24a5eda..c9e46bc1475aed0e35a48765ad70eef4362e8281 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -17,17 +17,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 8cc79a9bd0b7aa2098ce177a9d7749f4e6c6ac27..6430975335f5eef5b53c80213e6090ffd6166a91 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -92,39 +92,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( } Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, int num_constant_args, + const NameAttrList& function, const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, Signature* signature) { signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.resize(num_constant_args); - - signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); - - // Inputs are in the order: constants, non-constants, resource variables. - int input_num = 0; - // Use the values of compile time constants in the signature-> - while (input_num < num_constant_args) { - signature->arg_values[input_num] = ctx->input(input_num); - ++input_num; - } - // Add the types and shapes of the remaining arguments. - while (input_num < ctx->num_inputs() - variable_args.size()) { - signature->arg_types.emplace_back(ctx->input_dtype(input_num), - ctx->input(input_num).shape()); - ++input_num; - } - // For variable signatures, use the type and shape of the variable's - // current value. - for (auto& iterator : variable_args) { - const OptionalTensor& variable = iterator.second; - TF_RET_CHECK(input_num < ctx->num_inputs()); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); + signature->arg_values.reserve(constant_args.size()); + + signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + if (constant_args.count(i) > 0) { + // Use the values of compile time constants in the signature. + signature->arg_values.push_back(constant_args.at(i)); + } else if (variable_args.count(i) > 0) { + const OptionalTensor& variable = variable_args.at(i); + if (variable.present) { + signature->arg_types.emplace_back(variable.value.dtype(), + variable.value.shape()); + } else { + signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + } } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + signature->arg_types.emplace_back(ctx->input_dtype(i), + ctx->input(i).shape()); } - ++input_num; } return Status::OK(); } @@ -132,74 +123,58 @@ Status XlaCompilationCache::BuildSignature( namespace { // Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. The first `num_constant_args` arguments must be host-memory Tensors. -Status BuildArguments(int num_constant_args, +// op. +Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); - int input_num = 0; - - // Handles compile-time constants. - TF_RET_CHECK(num_constant_args <= ctx->num_inputs()); - while (input_num < num_constant_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - XlaCompiler::Argument& arg = (*args)[input_num]; - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - ++input_num; - } - - // Handles the non-constant arguments. - int num_variable_args = variable_args.size(); - int num_nonconst_args = - ctx->num_inputs() - num_variable_args - num_constant_args; - TF_RET_CHECK(num_nonconst_args >= 0); - while (input_num < num_constant_args + num_nonconst_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - ++input_num; - } - - // Handles resource variables. - TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs()); - for (auto& iterator : variable_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - - XlaCompiler::Argument& arg = (*args)[input_num]; - - arg.name = iterator.second.name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (iterator.second.present) { - const Tensor& value = iterator.second.value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do permit - // uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } } - ++input_num; } return Status::OK(); @@ -234,16 +209,43 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - int num_constant_args, const std::map& variable_args, - OpKernelContext* ctx, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options) { + return CompileImpl(options, function, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, false); +} + +Status XlaCompilationCache::CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + return CompileImpl(options, name, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, true); +} + +Status XlaCompilationCache::CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << num_constant_args + << " num_constant_args=" << constant_args.size() << " num_variable_args=" << variable_args.size(); for (int i = 0; i < ctx->num_inputs(); i++) { TensorShape shape = ctx->input(i).shape(); @@ -264,11 +266,12 @@ Status XlaCompilationCache::Compile( } } - TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs()); + TF_RET_CHECK(constant_args.size() + variable_args.size() <= + ctx->num_inputs()); Signature signature; - TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args, - ctx, &signature)); + TF_RETURN_IF_ERROR( + BuildSignature(function, constant_args, variable_args, ctx, &signature)); VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not @@ -295,13 +298,20 @@ Status XlaCompilationCache::Compile( // a long time.) std::vector args; TF_RETURN_IF_ERROR( - BuildArguments(num_constant_args, variable_args, ctx, &args)); + BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + + if (compile_single_op) { + entry->compilation_status = compiler.CompileSingleOp( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + signature.name, ctx, args, &entry->compilation_result); + } else { + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); + } } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index d5063783140205db54e673a7c7fd8f94b8aa2c65..5c0c79b880c474969464f23b4485734c404cef07 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -52,8 +52,8 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `num_constant_args` is the number of compile-time constant arguments to - // `function`. `variable_args` is a snapshot of the current values of the + // `constant_args` is a maps of tensorflow argument number to constant value. + // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. // The result of compilation is written to `*compilation_result`, which must @@ -62,19 +62,40 @@ class XlaCompilationCache : public ResourceBase { // executable pointer may be null if the computation has no non-constant // outputs. Status Compile(const XlaCompiler::Options& options, - const NameAttrList& function, int num_constant_args, + const NameAttrList& function, + const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options); + // As above, but calls XlaCompiler::CompileSingleOp instead of + // XlaCompiler::CompileFunction. + Status CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); + xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl(const XlaCompiler::Options& options, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. Status BuildExecutable(const XlaCompiler::Options& options, @@ -104,7 +125,8 @@ class XlaCompilationCache : public ResourceBase { static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, int num_constant_args, + Status BuildSignature(const NameAttrList& function, + const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, Signature* signature); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..682d6ea8ccc4a54912ccad4666cf0a7a03a7a698 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -0,0 +1,175 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the XlaCompileOnDemandOp. + +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +namespace { +std::map GetVariables(OpKernelContext* ctx) { + std::map variables; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dtype() == DT_RESOURCE) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& optional = variables[i]; + optional.name = handle.name(); + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + optional.present = true; + optional.value = *variable->tensor(); + } + } + } + return variables; +} +} // namespace + +Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, + const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable) { + std::map variables = GetVariables(ctx); + int64 num_resource_args = variables.size(); + + xla::LocalClient* client = metadata.client(); + + // Builds an XLA allocator for the device. + XlaComputationLaunchContext launch_context( + num_resource_args, client, client->backend().memory_allocator(), true); + + launch_context.PopulateInputs(ctx, result, variables); + + perftools::gputools::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + TF_RET_CHECK(stream); + + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(client->backend().memory_allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + TF_RETURN_IF_ERROR(run_result.status()); + + launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + return Status::OK(); +} + +bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // TODO(jmolloy): This could be expensive, so memoize. + auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + op_kernel->def().op()); + CHECK(constant_inputs); + std::set constant_input_indices; + for (const auto& name : *constant_inputs) { + int start, stop; + TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop)); + for (int i = start; i < stop; ++i) { + constant_input_indices.insert(i); + } + } + return constant_input_indices.count(argument_idx) > 0; +} + +bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // Right now we only create kConstant arguments when absolutely required, but + // there may be benefit in eagerly constant-folding a larger subset of + // arguments in the future. + return MustArgumentBeConstant(op_kernel, argument_idx); +} + +Status XlaCompileOnDemandOp::Compile( + OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable) { + std::map constant_arguments; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& device_tensor = ctx->input(i); + if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { + if (xla_tensor->has_host_tensor() && + ShouldArgumentBeConstant(&ctx->op_kernel(), i)) { + constant_arguments[i] = xla_tensor->host_tensor(); + } + } + if (constant_arguments.count(i) == 0 && + MustArgumentBeConstant(&ctx->op_kernel(), i)) { + // Slow path; the argument is not available as a host constant so we must + // fetch it synchronously. + Tensor host_tensor; + AllocatorAttributes attrs; + attrs.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp( + device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); + Notification n; + ctx->op_device_context()->CopyDeviceTensorToCPU( + &device_tensor, "ConstantArgument", + reinterpret_cast(ctx->device()), &host_tensor, + [&](Status status) { n.Notify(); }); + n.WaitForNotification(); + constant_arguments[i] = host_tensor; + } + } + + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + CHECK(rm); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + *cache = new XlaCompilationCache(metadata.client(), + metadata.jit_device_type()); + return Status::OK(); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + XlaCompiler::Options options; + DeviceType device_type = metadata.jit_device_type(); + options.device_type = &device_type; + options.client = metadata.client(); + options.flib_def = + new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + + std::map variable_args = GetVariables(ctx); + return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, + result, executable, + /*compile_options=*/nullptr); +} + +void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { + const XlaCompiler::CompilationResult* result; + xla::LocalExecutable* executable; + const XlaDevice::Metadata* metadata; + OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); + OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable)); + OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h new file mode 100644 index 0000000000000000000000000000000000000000..23c6f3903f841a6c39104983c6f7f409757a7319 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is +// called, will generate an xla::Computation and run it asynchronously. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ + +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// An OpKernel that compiles an op to an XLA computation and runs it. Unlike +// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// vanilla TensorFlow op as long as the bridge supports it. +// +// Importantly _XlaLaunch assumes all input and output tensors are on the host, +// whereas XlacompileOnDemandOp works with tensors in device memory. +class XlaCompileOnDemandOp : public OpKernel { + public: + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override; + + private: + XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); + bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); + Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e238252751e677eb947f6df03e3b2f2e948ffe19..bc07dbd7bdf005fde781f7a1e6775080e363abfb 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,6 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -34,14 +36,24 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); + bool compile_on_demand = flags->tf_xla_compile_on_demand; + + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_CPU_XLA_JIT; + registration.requires_compilation = !compile_on_demand; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } @@ -50,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index d4d8fe1c1d575b4e35d624621cc709e3a16569d5..12f471735f68394a3079541e9ac8532e329bd694 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -99,7 +100,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(backend, device_ordinal); + xla::MakeUnique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -108,21 +109,15 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, bool register_device_for_compilation, - std::unique_ptr* device) { + const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - if (register_device_for_compilation) { - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); - } + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -137,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( device->reset(new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie())); + platform.ValueOrDie(), transfer_as_literal)); return Status::OK(); } @@ -162,6 +157,7 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { + *metadata = nullptr; XlaDevice* xla_device = dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { @@ -177,13 +173,15 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform) + const DeviceType& jit_device_name, se::Platform* platform, + bool transfer_as_literal) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), - platform_(platform) {} + platform_(platform), + transfer_as_literal_(transfer_as_literal) {} XlaDevice::~XlaDevice() {} @@ -225,7 +223,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - auto ctx = new XlaDeviceContext(stream); + // Call GetAllocator for the side-effect of ensuring the allocator and + // XlaTensorInfoManager is created. + (void)GetAllocator({}); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -273,7 +274,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream); + XlaTransferManager manager(stream, client(), transfer_as_literal_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; @@ -288,19 +289,23 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { + // Any op assigned to the device that isn't rewritten by the graph rewriter + // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes + // it just-in-time. + kernel_factory::OpKernelRegistrar::Factory factory = + [](OpKernelConstruction* context) -> OpKernel* { + return new XlaCompileOnDemandOp(context); + }; XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; - auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { - return new XlaDeviceDummyOp(context); - }; for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( - new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", - dummy_factory)); + new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", + factory)); } return registrations; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d2ec38293c429f04f088bf3726ba97eb4e4b0dba..4fe7dd8c9fa9eb954804555e9615160dc4bc3e8a 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -26,6 +26,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -71,15 +73,20 @@ class XlaDevice : public LocalDevice { // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. + // 'transfer_as_literal' is true if device<->host transfers must be done using + // XLA's TransferLiteral{To,From}Device interface. If false, we can use + // ThenMemcpy instead. static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, - bool register_device_for_compilation, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - ::perftools::gputools::Platform* platform); + ::perftools::gputools::Platform* platform, + bool transfer_as_literal); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -104,7 +111,7 @@ class XlaDevice : public LocalDevice { // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - const DeviceType& jit_device_name_; + DeviceType jit_device_name_; // Memory allocator associated with this device. Allocator* xla_allocator_; // Not owned. ::perftools::gputools::Platform* platform_; // Not owned. @@ -113,9 +120,12 @@ class XlaDevice : public LocalDevice { // copying back and forth between CPU and the device, and // computations enqueued by XLA. xla::Backend::StreamPtr stream_; + // Must we use XLA's transfer manager for correct host<->device transfers? if + // false, we can use ThenMemcpy() instead. + bool transfer_as_literal_; }; -// Builds dummy OpKernel registrations on 'device' for the JIT operators +// Builds OpKernel registrations on 'device' for the JIT operators // registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations // object that encapsulates the kernel registrations. struct XlaDeviceOpRegistrations { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index c936222f32056e92efced82d5adb3a96c8041a17..6a57831cde1212671c253ef944e3379770db4a8d 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -26,33 +27,33 @@ namespace se = ::perftools::gputools; namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator(const xla::Backend* backend, - int device_ordinal) - : backend_(backend), device_ordinal_(device_ordinal) {} - +XlaDeviceAllocator::XlaDeviceAllocator() {} XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - se::DeviceMemoryBase dmem = - backend_->memory_allocator() - ->Allocate(device_ordinal_, num_bytes, /*retry_on_failure=*/false) - .ValueOrDie(); - VLOG(2) << "Allocated XLA device tensor " << dmem.opaque() << "(" << num_bytes - << ")"; - return dmem.opaque(); + // We always return an empty XlaTensor object, encoded as an opaque tagged + // pointer. We can return an empty object and ignore num_bytes here because we + // have control over all of the uses of this device tensor, and can lazily + // allocate memory when used. This allows us to also know the shape of the + // allocated Tensor, which is useful if the device's tensor representation + // differs from the host. + return XlaTensor::ToOpaquePointer(new XlaTensor()); } void XlaDeviceAllocator::DeallocateRaw(void* ptr) { - se::DeviceMemoryBase dmem(ptr); - TF_CHECK_OK(backend_->memory_allocator()->Deallocate(device_ordinal_, &dmem)); - VLOG(2) << "Deallocated XLA device tensor " << ptr; + delete XlaTensor::FromOpaquePointer(ptr); } void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {} +XlaTransferManager::XlaTransferManager(se::Stream* stream, + xla::LocalClient* client, + bool transfer_as_literal) + : stream_(stream), + client_(client), + transfer_as_literal_(transfer_as_literal) {} void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -68,18 +69,37 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* dst_ptr = DMAHelper::base(device_tensor); - se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + CHECK(xla_tensor); + if (!xla_tensor->has_shaped_buffer()) { + Status s = xla_tensor->AllocateShapedBuffer( + device_tensor->dtype(), device_tensor->shape(), client_, + stream_->parent()->device_ordinal()); + if (!s.ok()) { + done(s); + return; + } + } + + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = xla::Unimplemented( + "XlaTransferManager::CopyCPUTensorToDevice not implemented for " + "literals"); + } else { + stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } + xla_tensor->set_host_tensor(*cpu_tensor); done(status); return; @@ -103,18 +123,24 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, << device_tensor->NumElements(); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* src_ptr = const_cast(DMAHelper::base(device_tensor)); - se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); + se::DeviceMemoryBase dev_src_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); void* dst_ptr = DMAHelper::base(cpu_tensor); Status status; - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = xla::Unimplemented( + "XlaTransferManager::CopyDeviceTensorToCPU not implemented for " + "literals"); + } else { + stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } done(status); @@ -125,7 +151,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {} +XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, + bool transfer_as_literal) + : manager_(stream, client, transfer_as_literal) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index c4edcd474e48f791af9340c3cd6e4d031407bb68..a8ad511fbd2d7f06601608101b8346ff30f8fc20 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -26,11 +27,12 @@ limitations under the License. namespace tensorflow { -// The allocator used for Tensors assigned to the XLA device. It uses -// XLA backend's allocator. +// The allocator used for Tensors assigned to the XLA device. The allocator +// ignores the alignment and size of the request and always returns a new, +// empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(const xla::Backend* backend, int device_ordinal); + XlaDeviceAllocator(); ~XlaDeviceAllocator() override; string Name() override; @@ -38,18 +40,14 @@ class XlaDeviceAllocator : public Allocator { void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; void GetStats(AllocatorStats* stats) override; - - private: - // Which backend in the client this allocator belongs to. - const xla::Backend* backend_; - // Which hardware device in the client's backend this allocator belongs to. - const int device_ordinal_; }; // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(perftools::gputools::Stream* stream); + explicit XlaTransferManager(perftools::gputools::Stream* stream, + xla::LocalClient* client, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -62,6 +60,10 @@ class XlaTransferManager { // Stream obtained from a Device, used to transfer tensors between // CPU and device. perftools::gputools::Stream* stream_; + // For the underlying memory allocator and XLA's TransferManager. + xla::LocalClient* client_; + // True if we must use XLA's TransferManager for correct device transfers. + bool transfer_as_literal_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -69,7 +71,8 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(perftools::gputools::Stream* stream); + explicit XlaDeviceContext(perftools::gputools::Stream* stream, + xla::LocalClient* client, bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 2326070358d67c0cf30ef17fab5c93862cd8932c..ac60423d959ca44e7d92e2d965cf731287b1f83f 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -34,14 +34,21 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_GPU_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; std::unique_ptr device; - Status status = XlaDevice::Create( - "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device); + Status status = + XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; @@ -55,8 +62,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index a329451b14a785b17913e3838a6571b62b422804..9e098c46f422b436c722bb909dc58930ab7c0ef6 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -41,10 +41,17 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, - options, name_prefix, /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, + DEVICE_INTERPRETER_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 8322dd2e829a850413f8eee843b78052f6aad549..354be1e1b54b2f2e808b2216cfc1fe110dbb3857 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -52,78 +52,66 @@ std::map SnapshotResourceVariables(OpKernelContext* ctx, return snapshot; } -XlaAllocator::XlaAllocator(const gpu::Platform* platform, - OpKernelContext* op_context) - : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} +XlaAllocator::XlaAllocator(const gpu::Platform* platform, Allocator* wrapped) + : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} -XlaAllocator::~XlaAllocator() = default; +XlaAllocator::~XlaAllocator() {} xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { - AllocatorAttributes allocator_attrs; - allocator_attrs.set_on_host(false); - - AllocationAttributes allocation_attrs; - allocation_attrs.no_retry_on_failure = !retry_on_failure; - - Tensor t; - Status status = op_context_->allocate_temp( - DT_UINT8, TensorShape({static_cast(size)}), &t, allocator_attrs, - allocation_attrs); - if (!status.ok()) { - VLOG(2) << "Allocation failed " << size; - return status; + void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + if (data == nullptr) { + return errors::ResourceExhausted("Out of memory while trying to allocate ", + size, " bytes."); + } else { + return gpu::DeviceMemoryBase(data, size); } - void* data = - reinterpret_cast(const_cast(t.tensor_data().data())); - tensors_[data] = t; - return gpu::DeviceMemoryBase(data, size); -} - -Status XlaAllocator::RegisterArgument(const Tensor* t) { - void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); - tensors_[data] = *t; - return Status::OK(); } Status XlaAllocator::Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) { - if (mem->opaque() != nullptr) { - if (tensors_.erase(mem->opaque()) == 0) { - return tensorflow::errors::InvalidArgument("Unknown tensor address"); - } - } + wrapped_->DeallocateRaw(mem->opaque()); return Status::OK(); } -Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, - DataType dtype, - const TensorShape& shape, - Tensor* out_tensor) const { - void* ptr = const_cast(buffer.opaque()); - auto it = tensors_.find(ptr); - if (it == tensors_.end()) { - return errors::InvalidArgument("Unknown tensor address"); - } - const Tensor& tensor = it->second; - - int64 output_size = DataTypeSize(dtype) * shape.num_elements(); - if (tensor.TotalBytes() == output_size) { - out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape); - } else { - Tensor slice = tensor.Slice(0, output_size); - out_tensor->UnsafeCopyFromInternal(slice, dtype, shape); +namespace { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +std::unique_ptr ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator) { + xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_host_shape(), index); + xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_device_shape(), index); + + xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, + shaped_buffer->platform(), + shaped_buffer->device_ordinal()); + + auto& shape_tree = shaped_buffer->buffers(); + auto& sub_shape_tree = sub_shaped_buffer.buffers(); + sub_shape_tree.CopySubtreeFrom(shape_tree, + /*source_base_index=*/{index}, + /*target_base_index=*/{}); + for (auto& index_to_buffer : shape_tree) { + if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) { + index_to_buffer.second = gpu::DeviceMemoryBase(nullptr, 0); + } } - return Status::OK(); + return xla::ScopedShapedBuffer::MakeScoped(&sub_shaped_buffer, allocator) + .ValueOrDie(); } +} // namespace XlaComputationLaunchContext::XlaComputationLaunchContext( int64 num_resource_args, xla::LocalClient* client, - XlaAllocator* xla_allocator) + xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) : num_resource_args_(num_resource_args), client_(client), - xla_allocator_(xla_allocator) {} + xla_allocator_(xla_allocator), + allocate_xla_tensors_(allocate_xla_tensors) {} void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, @@ -145,29 +133,32 @@ void XlaComputationLaunchContext::PopulateInputs( t = &(ctx->input(arg_num)); } - gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( - const_cast(t->tensor_data().data()), t->tensor_data().size()); - const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); - CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) - << "On-device shape " - << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) - << " not the same as on-host shape " - << xla::ShapeUtil::HumanStringWithLayout(shape); - arg_buffers_[i] = xla::MakeUnique( - /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), - client_->default_device_ordinal()); - arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs_[i] = arg_buffers_[i].get(); - - OP_REQUIRES_OK(ctx, xla_allocator_->RegisterArgument(t)); + if (xla::ShapeUtil::IsTuple(on_device_shape)) { + const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); + arg_ptrs_[i] = + const_cast(&xla_tensor->shaped_buffer()); + } else { + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + gpu::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); + arg_buffers_[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, + client_->platform(), client_->default_device_ordinal()); + arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = arg_buffers_[i].get(); + } } } void XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - std::unique_ptr output) { + std::unique_ptr output) { gpu::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -180,36 +171,59 @@ void XlaComputationLaunchContext::PopulateOutputs( // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); if (kernel->outputs[i].is_constant) { // Output is a constant. const Tensor& const_tensor = kernel->outputs[i].constant_value; + Tensor* output_tensor; const size_t total_bytes = const_tensor.TotalBytes(); if (stream && total_bytes > 0) { // Copy host -> device. (Empty tensors don't have backing buffers.) VLOG(1) << "Constant output tensor on device"; - Tensor* output_tensor; - TF_CHECK_OK( - ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + + OP_REQUIRES_OK( + ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( + const_tensor.dtype(), const_tensor.shape(), + client_, stream->parent()->device_ordinal())); + } const void* src_ptr = DMAHelper::base(&const_tensor); - void* dst_ptr = DMAHelper::base(output_tensor); - gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); + gpu::DeviceMemoryBase dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*output_tensor); + // Memcpying asynchronously is safe for the GPU, but the CPU uses a + // shared allocator so hold a reference to the copied-to buffer until + // complete. + TensorReference ref(*output_tensor); + stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes); + stream->ThenDoHostCallback([ref] { ref.Unref(); }); } else { // No copy required. ctx->set_output(i, const_tensor); + output_tensor = ctx->mutable_output(i); + } + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + xla_tensor->set_host_tensor(const_tensor); } } else { const TensorShape& shape = kernel->outputs[i].shape; VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - Tensor output_tensor; - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK(ctx, xla_allocator_->MakeTensorFromBuffer( - buffer, ctx->expected_output_dtype(i), shape, - &output_tensor)); - ctx->set_output(i, output_tensor); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer( + ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + ctx->set_output(i, output_tensor); + } ++output_num; } @@ -221,6 +235,7 @@ void XlaComputationLaunchContext::PopulateOutputs( // Apply variable updates, if any. VLOG(2) << "Applying variable updates"; for (int i = 0; i < kernel->resource_updates.size(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), @@ -244,10 +259,21 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, errors::Internal("Mismatched type in variable write")); - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK(ctx, - xla_allocator_->MakeTensorFromBuffer( - buffer, write.type, write.shape, variable->tensor())); + if (allocate_xla_tensors_) { + Tensor output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer( + ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + *variable->tensor() = output_tensor; + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + write.type, write.shape, buffer, allocator); + output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + *variable->tensor() = output_tensor; + } ++output_num; } } diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 9fd356fce5896c317196cb31fd5248b6bc3427a8..14f70fe35891040ff3460567adb223be0f1c910f 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -19,8 +19,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" @@ -45,24 +47,13 @@ std::map SnapshotResourceVariables(OpKernelContext* ctx, class XlaAllocator : public xla::DeviceMemoryAllocator { public: XlaAllocator(const perftools::gputools::Platform* platform, - OpKernelContext* op_context); + Allocator* wrapped); ~XlaAllocator() override; xla::StatusOr Allocate( int device_ordinal, uint64 size, bool retry_on_failure) override; Status Deallocate(int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; - // Register an Tensor (input or resource variable) with the allocator. If - // the operation returns an alias to one of its inputs, then the allocator - // needs to be able to handle it. - Status RegisterArgument(const Tensor* t); - - // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is - // interpreted as having data type 'dtype' and shape 'shape'. - Status MakeTensorFromBuffer(perftools::gputools::DeviceMemoryBase buffer, - DataType dtype, const TensorShape& shape, - Tensor* out_tensor) const; - // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place. Tensorflow uses the ordering of the main // compute stream to enforce a happens-before relationship between a memory @@ -73,20 +64,19 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { bool AllowsAsynchronousDeallocation() const override { return true; } private: - OpKernelContext* const op_context_; - - // Map from pointer address to the owning Tensor; used by - // MakeTensorFromBuffer. Also used to automatically release Tensors when the - // allocator is freed. - std::unordered_map tensors_; + Allocator* wrapped_; }; // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. class XlaComputationLaunchContext { public: + // Create a new launch context. 'allocate_xla_tensors' is true if allocated + // output tensors and variables are always XlaTensors. If false they are + // assumed to be "normal" device pointers. XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, - XlaAllocator* xla_allocator); + xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. @@ -97,7 +87,7 @@ class XlaComputationLaunchContext { // Given the XLA output in `output`, populate all outputs of `ctx`. void PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - std::unique_ptr output); + std::unique_ptr output); // Return the argument list. Only valid after PopulateInputs() has been // called. @@ -106,11 +96,53 @@ class XlaComputationLaunchContext { private: int64 num_resource_args_; xla::LocalClient* client_; - XlaAllocator* xla_allocator_; + xla::DeviceMemoryAllocator* xla_allocator_; + bool allocate_xla_tensors_; std::vector> arg_buffers_; std::vector arg_ptrs_; }; +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class XlaTensorBuffer : public TensorBuffer { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) + : expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) { + data_ = const_cast(ptr); + } + + ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + + void* data() const override { return data_; } + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(actual_size_); + } + + static Tensor MakeTensor(DataType dtype, const TensorShape& shape, + perftools::gputools::DeviceMemoryBase buffer, + Allocator* allocator) { + size_t expected_size = shape.num_elements() * DataTypeSize(dtype); + auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, + buffer.size(), allocator); + Tensor t(dtype, shape, tensor_buffer); + tensor_buffer->Unref(); + return t; + } + + private: + void* data_; + size_t expected_size_; + size_t actual_size_; + Allocator* allocator_; +}; + } // namespace tensorflow #endif diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc new file mode 100644 index 0000000000000000000000000000000000000000..956328e6757f4c903e3995a54635682d19052794 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" + +namespace tensorflow { + +/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { + if (tensor->NumElements() == 0) { + return nullptr; + } + XlaTensor* xla_tensor = + FromOpaquePointer(const_cast(tensor->tensor_data().data())); + return xla_tensor; +} + +/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { + return FromTensor(const_cast(tensor)); +} + +/*static*/ perftools::gputools::DeviceMemoryBase +XlaTensor::DeviceMemoryFromTensor(const Tensor& tensor) { + const XlaTensor* xla_tensor = FromTensor(&tensor); + if (xla_tensor) { + CHECK(xla_tensor->has_shaped_buffer()); + return xla_tensor->shaped_buffer().root_buffer(); + } else { + return perftools::gputools::DeviceMemoryBase( + const_cast(tensor.tensor_data().data()), + tensor.tensor_data().size()); + } +} + +Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, + int device_ordinal) { + xla::Shape on_host_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); + xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape( + on_host_shape); + + xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), + device_ordinal); + for (auto& index_to_buffer : buffer.buffers()) { + xla::Shape subshape = + xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); + uint64 size = + client->backend().transfer_manager()->GetByteSizeRequirement(subshape); + TF_ASSIGN_OR_RETURN(index_to_buffer.second, + client->backend().memory_allocator()->Allocate( + device_ordinal, size, /*retry_on_failure=*/false)); + } + + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + xla::ScopedShapedBuffer::MakeScoped( + &buffer, client->backend().memory_allocator())); + set_shaped_buffer(std::move(scoped_buffer)); + return Status::OK(); +} + +// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from +// device-side tensors, which are either CPU or GPU memory pointers. This works +// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. +namespace { +constexpr uintptr_t kTag = 0x1ULL; +} + +/*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) { + uintptr_t value = reinterpret_cast(ptr); + if (value & kTag) { + return reinterpret_cast(value & ~kTag); + } else { + return nullptr; + } +} + +/*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) { + uintptr_t value = reinterpret_cast(tensor); + CHECK_EQ(value & kTag, 0); + value |= kTag; + return reinterpret_cast(value); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff2fb08f03548260215c6aeded2c124f8d28f43 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The implementation of a Tensor for an XlaDevice. All device tensors are +// actually one of these. +// +// To distinguish between "normal" device tensors and XlaTensors, the raw +// pointer data stored in the TensorBuffer is a tagged pointer. +class XlaTensor { + public: + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static XlaTensor* FromTensor(Tensor* tensor); + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static const XlaTensor* FromTensor(const Tensor* tensor); + + // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in + // which case the returned value is shaped_buffer()->root_buffer(), or a + // normal Tensor in which case the returned value is + // {tensor.tensor_data().data(), tensor.tensor_data().size}. + static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor( + const Tensor& tensor); + + // Assign the internal ShapedBuffer to new memory for the given dtype and + // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it + // is replaced and the managed memory deallocated. + Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, int device_ordinal); + + // Some Tensors can have complex on-device shapes, including tuple shapes. To + // manage the memory for these tensors a ShapedBuffer may be required. + + // Return true if this TensorInfo contains a ShapedBuffer. + bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } + // Return the contained ShapedBuffer. + // REQUIRES: has_shaped_buffer() + const xla::ShapedBuffer& shaped_buffer() const { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + // Mutates the TensorInfo to set the ShapedBuffer. + void set_shaped_buffer( + std::unique_ptr shaped_buffer) { + shaped_buffer_ = std::move(shaped_buffer); + } + + // Some tensors on the device may have known values on the host. We use these + // in on-demand mode to avoid re-copying values from the device if we know the + // host value already. + + // Return true if this TensorInfo contains a host tensor. + bool has_host_tensor() const { return host_tensor_ != nullptr; } + // Return the contained host tensor. + // REQUIRES: has_host_tensor() + const Tensor& host_tensor() const { return *host_tensor_; } + // Sets the contained host tensor. + void set_host_tensor(const Tensor& tensor) { + host_tensor_.reset(new Tensor(tensor)); + } + + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. + static XlaTensor* FromOpaquePointer(void* ptr); + // Convert to a raw pointer from an XlaTensor, adding the pointer tag. + static void* ToOpaquePointer(XlaTensor* tensor); + + private: + // The optional contained ShapedBuffer. + std::unique_ptr shaped_buffer_; + // An optional host tensor value. + std::unique_ptr host_tensor_; +}; + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index da4bc44c7a75c9f8faf16c537a17a1f2d16d5d61..238fd15166c0b08ee109d6a3888e16c39f87a603 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -49,17 +49,3 @@ cc_library( "//tensorflow/compiler/jit:xla_device", ], ) - -#----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 85a2adab283c273af607b6b80fb4fd76f8dac2b2..edabdc218a3d8782d524aee01833db3179cafbc9 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -86,7 +86,10 @@ tf_xla_py_test( # ArgMax needs CustomCall on CPU, which is not available in normal # (not precompiled) TensorFlow. The flag below excludes the CPU # backend. - disabled_backends = "cpu", + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -315,6 +318,8 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + # Functions are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -537,7 +542,6 @@ tf_xla_py_test( size = "medium", srcs = ["spacetobatch_op_test.py"], shard_count = 3, - tags = ["notsan"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -551,6 +555,8 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + # Stack ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -577,6 +583,8 @@ tf_xla_py_test( name = "tensor_array_ops_test", size = "small", srcs = ["tensor_array_ops_test.py"], + # TensorArray ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -827,17 +835,3 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index ba7b9bacd2b794c74409d517a9c05bfbb14a845f..d1d7379c0a32eff4ff96e791dacbe800bbd70b7d 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -190,19 +190,24 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) - self._testBinary( - gen_nn_ops.sparse_softmax_cross_entropy_with_logits, - np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], - [0.9, 1.0, 1.1, 1.2]], dtype=dtype), - np.array([2, 1, 7], dtype=np.int32), - expected=[ - np.array([1.342536, 1.442536, np.nan], dtype=dtype), - np.array([[0.213838, 0.236328, -0.738817, 0.288651], - [0.213838, -0.763672, 0.261183, 0.288651], - [np.nan, np.nan, np.nan, np.nan]], - dtype=dtype), - ], - equality_test=self.ListsAreClose) + # TODO(b/68813416): Fails with bfloat16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits, + np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2]], + dtype=dtype), + np.array([2, 1, 7], dtype=np.int32), + expected=[ + np.array([1.342536, 1.442536, np.nan], dtype=dtype), + np.array( + [[0.213838, 0.236328, -0.738817, 0.288651], [ + 0.213838, -0.763672, 0.261183, 0.288651 + ], [np.nan, np.nan, np.nan, np.nan]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) def testIntOps(self): for dtype in self.int_types: @@ -260,12 +265,6 @@ class BinaryOpsTest(XLATestCase): np.array([[1], [2]], dtype=dtype), dtype(7), expected=np.array([[8], [9]], dtype=dtype)) - self._testBinary( - math_ops.add, - np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), - np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), - expected=np.array( - [1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) self._testBinary( math_ops.subtract, @@ -361,6 +360,12 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), + np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), + expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) + def testComplexOps(self): for dtype in self.complex_types: ctypes = {np.complex64: np.float32} diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 0528a5415d579a844e68403ace1bb8982a10a841..a9db1c173d33b0bc44248a4b55c678f7083b5527 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -56,7 +56,7 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: @@ -89,4 +89,3 @@ def generate_backend_suites(backends=[]): backends = all_backends() for backend in backends: native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) - diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3..1a8989d7c2f617525c301f30fd899a01362310bf 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -34,6 +34,13 @@ from tensorflow.python.platform import test class CholeskyOpTest(XLATestCase): + # Cholesky defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/cholesky) + @property + def float_types(self): + return set(super(CholeskyOpTest, self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) self.assertAllClose(x, verification_np, atol=atol) diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 2d8236e2cbdfafb35626cd582ee39b1f917aec7f..f9d87c2d1cfe5c1a7487e124c971a54ffcfede15 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.compiler import jit @@ -436,5 +437,55 @@ class XlaCompilationTest(test.TestCase): self.assertTrue(InLabels(labels, "_XlaLaunch")) +class ElementWiseFusionTest(test.TestCase): + + # Runs a simple test with the input jit_level and fusion_only flag. + def simpleTest(self, arg0, arg1, global_jit_level): + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = global_jit_level + + with session_lib.Session(config=config) as sess: + a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1") + a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2") + # Two element-wise ops. We need at least two ops since single + # element clusters are not passed to XLA in fusion_only mode. + a3 = a1 * a2 + a4 = a3 + a1 + # A matmul to break XLA clustering. + a5 = math_ops.matmul(a4, a1) + # Two more element-wise ops. + a6 = a5 - a4 + a7 = a6 + a2 + + run_metadata = config_pb2.RunMetadata() + output = sess.run( + a7, { + a1: arg0, + a2: arg1 + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = RunMetadataLabels(run_metadata) + count = sum("_XlaLaunch(" in x for x in labels) + + return output, count + + def testElementWiseClustering(self): + arg0 = np.random.rand(2, 2).astype(np.float32) + arg1 = np.random.rand(2, 2).astype(np.float32) + os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true" + tf_op, tf_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.OFF) + self.assertEqual(0, tf_count) + + tfef_op, tfef_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.ON_1) + self.assertEqual(2, tfef_count) + + self.assertAllClose(tf_op, tfef_op, rtol=1e-1) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd..5819b2bf2b55b9213a039c0ba82dd0bf1c738b00 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -37,6 +37,14 @@ def MakePlaceholder(x): class MatrixTriangularSolveOpTest(XLATestCase): + # MatrixTriangularSolve defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) + @property + def float_types(self): + return set(super(MatrixTriangularSolveOpTest, + self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol): diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e72dd4eea9f127e1df96ab166103c4c16372adb6..e53efc3091d8935e745122af29abd7b8063b1d01 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64, DT_INT64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 92518aadc4bf5c601cfb4192c093799784b6aa72..60839814931eaeb0b78a20fd1e4f387d241cd56f 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import test @@ -156,6 +157,12 @@ class SpaceToBatchNDTest(XLATestCase): paddings = np.array(paddings).reshape((len(block_shape), 2)) with self.test_session() as sess, self.test_scope(): for dtype in self.float_types: + # TODO(b/68813416): Skip bfloat16's as the input type for direct is + # float32 and results in a mismatch, while making testDirect provide the + # correctly typed input results in 'no fill-function for data-type' + # error. + if dtype == dtypes.bfloat16.as_numpy_dtype: + continue placeholder = array_ops.placeholder(dtype) # outputs = space_to_batch(inputs) x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 3d3e112f4821ea8e57ea9589a5b4433647ad294b..17149aa1c8edddadc504e916915a70f78abf8002 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -600,6 +600,20 @@ class UnaryOpsTest(XLATestCase): src, expected=dst) + def testBitcast(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1, 0x3f800000], np.int32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.float32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1e-45, 1.0], np.float32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1e-45, 1.0], np.float32), + expected=np.array([1, 0x3f800000], np.int32)) + def testInvertPermutation(self): self._assertOpOutputMatchesExpected( array_ops.invert_permutation, @@ -779,7 +793,10 @@ class UnaryOpsTest(XLATestCase): self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) self._assertSoftplusMatchesExpected( [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) - log_eps = np.log(np.finfo(dtype).eps) + if dtype == dtypes.bfloat16.as_numpy_dtype: + log_eps = np.log(np.finfo(np.float32).eps) + else: + log_eps = np.log(np.finfo(dtype).eps) one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index b08d6ab21e0746558cb3d4818d4c822c45d2e9ee..8ecad00f6e23b3a7746bbb473102ac847bf4cbfd 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -230,7 +230,10 @@ class SliceAssignTest(XLATestCase): # shrink shape changes checker[1:2, 1] = [66] checker[1, 1:2] = [66] - checker[1, 1] = 66 + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker[1, 1] = 66 # newaxis shape changes checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] # shrink and newaxis @@ -243,8 +246,11 @@ class SliceAssignTest(XLATestCase): # Assign vector to scalar (rank-0) using newaxis checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) - checker2[()] = 6 # no indices - checker2[...] = 6 # ellipsis + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis checker2[None] = [6] # new axis def testUninitialized(self): diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index cc778f1c3c0098da5ab933f9b4674890a724d160..e924fe1e61454aefda622a5a46a0e483d26db5c1 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import contextlib +import os import random import re @@ -44,6 +45,8 @@ flags.DEFINE_string('test_device', None, flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') flags.DEFINE_string('disabled_manifest', None, 'Path to a file with a list of tests that should not run.') +flags.DEFINE_string('tf_xla_flags', None, + 'Value to set the TF_XLA_FLAGS environment variable to') class XLATestCase(test.TestCase): @@ -97,6 +100,8 @@ class XLATestCase(test.TestCase): disabled_tests = [] disabled_method_types = [] for l in manifest_file.read().splitlines(): + if not l: + continue entry = comments_re.sub('', l).strip().split(' ') if len(entry) == 1: disabled_tests.append(entry[0]) @@ -113,6 +118,9 @@ class XLATestCase(test.TestCase): for name in types]) manifest_file.close() + if FLAGS.tf_xla_flags is not None: + os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags + @property def all_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index eb20ca501c80b01c76198e1ad54173f1c601714d..e7daf4e01c45c3705216fce7dd3db5baa0c261fc 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -332,6 +332,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -462,17 +463,3 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 311dddca94c458a60fd00afe5532840e0dbf0437..c30bb9cacd48fb93ac359a6a25699ba6a74183c5 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -51,17 +51,3 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 6f46532419d3389bafe8c3bf41fa41e8a3e173b7..de1008803d69fefa415c7bdbe6c27a62e625b417 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -55,8 +55,10 @@ Status BackwardsConstAnalysis(const Graph& g, compile_time_const_args->at(index) = true; return; } - for (const Node* pred : node->in_nodes()) { - must_be_const.insert(pred); + for (const Edge* pred : node->in_edges()) { + if (!pred->IsControlEdge()) { + must_be_const.insert(pred->src()); + } } return; } diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 9d125f8d499863cfaa0e26b5b633ca02914d1b7d..992b12c06db5efc0ae54284d0ea77017c1c79aca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -79,5 +79,24 @@ TEST(ConstAnalysisTest, TopologicalOrder) { } } +TEST(ConstAnalysisTest, DontFollowControlDependencies) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg1, c1); + Output reshape = ops::Reshape(root, arg1, add); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + EXPECT_EQ(const_args, std::vector({false, true})); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d2fa933cf9c085f92b2f442827a94d72938e4bb2..f1bc7d6af49a09f84ef251eaa1c3d684792d0c1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -93,6 +93,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -154,6 +155,22 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "if_op", + srcs = ["if_op.cc"], + hdrs = ["if_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. tf_kernel_library( @@ -200,17 +217,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a249b1869f547f8e5aa725f9f5cf391b10429928..931175be1111ed5f70afbdf351ee53c59c1367de 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -118,30 +118,24 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); - - auto grad_backprop = ctx->Input(0); - auto activations = ctx->Input(1); - auto scale = ctx->Input(2); - auto mean = ctx->Input(3); - auto var = ctx->Input(4); - - TensorShape input_shape = ctx->InputShape(0); - int feature_index = - GetTensorFeatureDimIndex(input_shape.dims(), data_format_); - + xla::ComputationBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); - xla::PrimitiveType input_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); - xla::PrimitiveType scale_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). - grad_backprop = b->ConvertElementType(grad_backprop, scale_type); - activations = b->ConvertElementType(activations, scale_type); + auto grad_backprop = + XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + auto activations = + XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + auto scale = ctx->Input(2); + auto mean = ctx->Input(3); + auto var = ctx->Input(4); + + const int input_dims = ctx->InputShape(0).dims(); + const int feature_index = + GetTensorFeatureDimIndex(input_dims, data_format_); xla::ComputationDataHandle x_backprop; xla::ComputationDataHandle scale_backprop; @@ -156,7 +150,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = b->GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. - std::vector reduction_dims(input_shape.dims() - 1); + std::vector reduction_dims(input_dims - 1); std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, 0); std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), @@ -165,9 +159,14 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + // epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - offset_backprop = - b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), - *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(scale_dtype); + auto converted = + XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -175,17 +174,21 @@ class FusedBatchNormGradOp : public XlaOpKernel { b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) - auto scratch2 = b->Reduce( - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), - XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), - reduction_dims); + auto mul = + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(0, + XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index cbade79e85eed10ecb5ead7151ee778c86a0de37..569950c2dfaeb61028049a263a962dfa54a62e09 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -184,9 +184,7 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace") - .CompileTimeConstInput("crops") - .CompileTimeConstInput("block_shape"), +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), BatchToSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index c667b4e3e326b776faba49387760abbd582fcc68..ed33b8ed2e823f313a9a7fe220390bc617288405 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -103,10 +103,15 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationDataHandle result = ctx->builder()->Reduce( - ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), reduce_dims); - ctx->SetOutput(0, result); + xla::ComputationBuilder* const b = ctx->builder(); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 43a6a747c6bcc441f33f276fde4a66f367d99731..c52b2dcb7e9ef81fd52565dfbda05e33a52ed43a 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -62,5 +62,50 @@ class CastOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Cast"), CastOp); +class BitcastOp : public XlaOpKernel { + public: + explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle output; + + if (src_dtype_ == dst_dtype_) { + output = input; + } else { + // The only complex type in XLA is C64, so error out if the bitcast has a + // complex source or destination type and the bitcast is not trivial. + OP_REQUIRES(ctx, + !xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_), + errors::Unimplemented("Complex types not supported.")); + // XLA bitcast requires that the bit-width of the source and destination + // matches, and currently only the simple lowering is performed. + OP_REQUIRES(ctx, + xla::primitive_util::BitWidth(src_type_) == + xla::primitive_util::BitWidth(dst_type_), + errors::Unimplemented( + "Only bitcasts between equally sized types supported.")); + output = builder->BitcastConvertType(input, dst_type_); + } + + ctx->SetOutput(0, output); + } + + protected: + DataType src_dtype_, dst_dtype_; + xla::PrimitiveType src_type_, dst_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp); +}; + +REGISTER_XLA_OP(Name("Bitcast"), BitcastOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 81cea6d376d02c956a5257c5475fe5c10b83deb9..c0ee0c9c2ea849a692bee70bba36d32335eed9b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -58,7 +58,7 @@ xla::ComputationDataHandle CreateExpandedZero( // Create a mask for depthwise convolution that will make a normal convolution // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// depthwise filter this returns a [2, 2, 3, 6] tensor // 1 1 0 0 0 0 1 1 0 0 0 0 // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 @@ -166,6 +166,10 @@ xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); return builder->Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with + // ExpandedZero guarantees that only one element is non zero, so there + // cannot be accumulated precision error. builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), *ctx->GetOrCreateAdd(dtype), {expanded_filter_shape.dims() - 2}), diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 453a32c494b42e9922bc35fc526f3306530054fd..99470d70e709ddb5593c5eaae061bb897befc168 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -247,6 +247,8 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { const TensorShape gradient_shape = ctx->InputShape(0); xla::ComputationDataHandle input = ctx->Input(1); const DataType data_type = ctx->input_type(1); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(data_type); xla::ComputationDataHandle input_min = ctx->Input(2); xla::ComputationDataHandle input_max = ctx->Input(3); @@ -265,15 +267,23 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { ctx->SetOutput(0, output0); xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes); + xla::ComputationDataHandle reduce1 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output1 = - b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes); + xla::ComputationDataHandle reduce2 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output2 = - b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eefbe55c815d80a608bdf62d454a69d722adb158 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -0,0 +1,226 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/if_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr)); + then_branch_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr)); + else_branch_ = *name_attr; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. +void XlaIfOp::Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + + OP_REQUIRES(ctx, cond_type_ == DT_BOOL, + errors::InvalidArgument( + "Condition argument must be a boolean for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "Condition argument must be a scalar for XLA compilation")); + + VLOG(1) << "Building If: " << input_types_.size() << " inputs"; + + std::vector inputs(input_types_.size()); + std::vector arguments(input_types_.size()); + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + inputs[i] = ctx->Input(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString(); + } + } + + // Compile both branches of the conditional. + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult then_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + XlaCompiler::CompilationResult else_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + } + } + + // Check that both branches have identical input shapes. + OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), + errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape))); + + // Check that both branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape), + errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + OP_REQUIRES(ctx, + then_result.resource_updates.size() == + else_result.resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in then and else branch")); + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES( + ctx, equal, + errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i)); + } + + xla::ComputationDataHandle outputs = + b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, + b->Tuple(inputs), *else_result.computation); + // Sets non-variable outputs. + for (int i = 0; i < output_types_.size(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + } + + // Updates the values of any resource variables modified by the conditional + // bodies. + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (int i = 0; i < result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result->resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = result->outputs.size() + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + b->GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "If variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building If"; +} + +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f9bc98a198a72dcc0594e61971713bf890ce30b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp); + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index f22f384256a8ddd8c05de4a1322aba741dc4d7fd..5eeda79a935e8194a596d322b52add27846d378c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -180,9 +180,13 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); - auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), - /*computation=*/*context->GetOrCreateAdd(type), + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); + auto output = XlaHelpers::ConvertElementType(b, reduce, type); output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index d096415087e47a73503a06526ab133ac34803c5d..c177f08d9c4687bb13b98a4328bb3960519799c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -29,21 +29,22 @@ class L2LossOp : public XlaOpKernel { explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); + std::vector dims(ctx->InputShape(0).dims()); + std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - - auto zero = XlaHelpers::Zero(b, dtype); - auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - const xla::Computation& add = *ctx->GetOrCreateAdd(dtype); - - std::vector dims(input_shape.dims()); - std::iota(dims.begin(), dims.end(), 0); + xla::ComputationBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 - auto x = ctx->Input(0); - ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two)); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto t = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto square = b->Mul(t, t); + auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); + ctx->SetOutput(0, b->Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 759d1a1a2d996d4f5deb1774be7014bb6de30f40..1cfee3070f384af0a7441a9c860c530dd1b42187 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -47,12 +47,17 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. - auto squared = builder->Mul(input, input); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto scale = builder->Pow( builder->Add(builder->ConstantR0(bias_), @@ -130,12 +135,17 @@ class LRNGradOp : public XlaOpKernel { // dyi *= out_grads[j] // grads[k] += dyi - auto squared = builder->Mul(in_image, in_image); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = builder->Add(builder->ConstantR0(bias_), @@ -146,11 +156,15 @@ class LRNGradOp : public XlaOpKernel { builder->Div(out_image, norm)), in_grads); - auto dy_reduced = builder->ReduceWindow( - dy, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto converted_dy = + XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto dy_reduce = builder->ReduceWindow( + converted_dy, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto dy_reduced = + XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); xla::ComputationDataHandle gradients = builder->Add( builder->Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 086a9491aa93ebfae99f296dd355ae2e322084ec..5f635dd1bc6122cfcac8163baafd95b13f157715 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -35,8 +35,11 @@ namespace { // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: - PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) - : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims, + const DataType reduction_type) + : XlaOpKernel(ctx), + num_spatial_dims_(num_spatial_dims), + reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { std::vector ksize_int; std::vector stride_int; @@ -63,12 +66,10 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) = 0; + virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) = 0; + virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. virtual xla::ComputationDataHandle PostProcessOutput( @@ -76,9 +77,6 @@ class PoolingOp : public XlaOpKernel { DataType dtype, const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - std::vector ksize = ksize_; std::vector stride = stride_; if (ctx->num_inputs() != 1) { @@ -106,16 +104,20 @@ class PoolingOp : public XlaOpKernel { stride.clear(); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); } + const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), " dimensions")); - const DataType type = input_type(0); - xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, - stride, padding_); - ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); + xla::ComputationBuilder* const b = ctx->builder(); + auto input = + XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); + auto reduce = ctx->builder()->ReduceWindow( + input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + ctx->SetOutput(0, + PostProcessOutput(ctx, pooled, input_type(0), input_shape)); } protected: @@ -124,21 +126,21 @@ class PoolingOp : public XlaOpKernel { std::vector stride_; xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; + DataType reduction_type_; }; class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::MinValue(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateMax(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateMax(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -209,15 +211,17 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( } // Build a matrix of all 1s, with the same width/height as the input. + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto ones = ctx->builder()->Broadcast( - XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); + XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto counts = ctx->builder()->ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), dtype), - *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, + auto reduce = ctx->builder()->ReduceWindow( + ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); + auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); return ctx->builder()->Div(output, counts, window_dims); } @@ -226,16 +230,16 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateAdd(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateAdd(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -455,14 +459,12 @@ class AvgPoolGradOp : public XlaOpKernel { gradients_shape, filter_shape, out_backprop_shape, stride_, padding_, data_format_, &dims)); + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + xla::ComputationBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - - // The input gradients are computed by a convolution of the output - // gradients - // and the filter, with some appropriate padding. See the comment at - // the top of conv_grad_ops.h for details. - DataType dtype = input_type(1); - + auto dtype = input_type(1); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; @@ -483,17 +485,18 @@ class AvgPoolGradOp : public XlaOpKernel { padding->set_interior_padding(dims.spatial_dims[i].stride - 1); } - auto zero = XlaHelpers::Zero(ctx->builder(), dtype); - auto padded_gradients = - ctx->builder()->Pad(out_backprop_div, zero, padding_config); + auto zero = XlaHelpers::Zero(b, dtype); + auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); - xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( - padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, + auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto in_backprop = b->ReduceWindow( + XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), ksize_, /* window_strides=*/ones, xla::Padding::kValid); - - ctx->SetOutput(0, in_backprop); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 03b13b2924f4b81c1017804c91d5ffb81c44ea0b..812d258cd1677e18ef49952044126c76a2f55b19 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -27,7 +27,13 @@ namespace { class SumOp : public XlaReductionOp { public: - explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit SumOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -39,11 +45,13 @@ REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: - explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit ProdOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - return XlaHelpers::One(builder, input_type(0)); + return XlaHelpers::One(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -58,13 +66,12 @@ REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), class MinOp : public XlaReductionOp { public: - explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MinOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MaxValue(type)); + return XlaHelpers::MaxValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -78,13 +85,12 @@ REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: - explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MaxOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MinValue(type)); + return XlaHelpers::MinValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -98,8 +104,14 @@ REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: - explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MeanOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -121,7 +133,8 @@ REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), class AllOp : public XlaReductionOp { public: - explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AllOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { @@ -139,7 +152,8 @@ REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: - explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AnyOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 9aca6d8fedf92f176b3b7b40c5961d4a2e557a8a..f3181f0dadc2d3f45abb145e009e2663c10490f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -33,12 +33,12 @@ namespace tensorflow { // xla::ComputationBuilder. class XlaReductionOp : public XlaOpKernel { public: - explicit XlaReductionOp(OpKernelConstruction* ctx); + XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} - // Return the base case for the reduction. Defaults to zero. + // Return the base case for the reduction. virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder); + xla::ComputationBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired @@ -63,6 +63,9 @@ class XlaReductionOp : public XlaOpKernel { private: // True if the number of dimensions should be maintained. bool keep_dims_; + + protected: + DataType reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4b5d09eb9fd4110cdc4221099ff55767e9132540..64fe765ae9a945c58ea60bc157b1520c83b0d8e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -24,19 +24,15 @@ limitations under the License. namespace tensorflow { -XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { +XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, + DataType reduction_type) + : XlaOpKernel(ctx), reduction_type_(reduction_type) { const DataType dt = BaseType(input_type(0)); OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); } -// Return the base case for the reduction. Defaults to zero. -xla::ComputationDataHandle XlaReductionOp::InitialValue( - xla::ComputationBuilder* builder) { - return XlaHelpers::Zero(builder, input_type(0)); -} - // Unless BuildFinalizer is overridden the reduction has no // finalizer. xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( @@ -100,36 +96,26 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - // Call virtual method to get the initial value. - const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); + xla::ComputationBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(ctx->builder()->client(), - strings::StrCat(desc, "-reduction")); + xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - // Make two scalar parameters of the desired type for the lambda. - xla::ComputationDataHandle rx = - r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - xla::ComputationDataHandle ry = - r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); - - auto data = ctx->Input(0); + TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); + auto data = b->ConvertElementType(ctx->Input(0), type); + // Call virtual method to get the initial value. + auto initial = b->ConvertElementType(InitialValue(b), type); + // Make two scalar parameters of the desired type for the lambda. + auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle reduce = - ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - xla::ComputationDataHandle finalized = - BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); - - xla::ComputationDataHandle result; - if (keep_dims_) { - result = ctx->builder()->Reshape(finalized, final_shape); - } else { - result = finalized; - } + auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f..4cfa28a0ce3d7d1f24196ef6ef2775f840b2bcf1 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -66,7 +66,7 @@ class ScanOp : public XlaOpKernel { -input_shape.dims(), ", ", input_shape.dims(), "), but got ", axis)); - DataType dtype = ctx->input_type(0); + DataType dtype = XlaHelpers::SumAccumulationType(ctx->input_type(0)); if (input_shape.num_elements() == 0) { // Exit early if there is nothing to compute. @@ -91,7 +91,6 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle input = ctx->Input(0); xla::ComputationDataHandle init; const xla::Computation* reducer; if (sum_) { @@ -102,7 +101,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = builder->ReduceWindowWithGeneralPadding( - ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, + *reducer, window_dims, window_strides, padding); + output = + XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 80d6df6c48b0141734dcee1c2a3c413926931feb..498342a98881df0c6ff50007eacc1d5ef6196b57 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -83,7 +83,9 @@ class UnsortedSegmentSum : public XlaOpKernel { DataType dtype_; }; -REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); +REGISTER_XLA_OP( + Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), + UnsortedSegmentSum); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 750a4c2dec8154f97f307978b3d8884271292279..463788b8b461c370a8e7ab4d79a94fc0143b8b45 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -28,7 +29,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = StringPiece(type_string()).starts_with("Log"); + log_ = str_util::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { @@ -42,9 +43,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationBuilder* const b = ctx->builder(); const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -52,21 +52,20 @@ class SoftmaxOp : public XlaOpKernel { // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - xla::ComputationDataHandle softmax; - if (log_) { - // softmax = shifted_logits - log(sum(exp(shifted_logits))) - auto log_sum_exp = - b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type), - add_func, {kClassDim})); - softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - } else { - // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - auto exp_shifted = b->Exp(shifted_logits); - auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func, - {kClassDim}); - softmax = b->Div(exp_shifted, sum_exp, {kBatchDim}); - } - + auto exp_shifted = b->Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : b->Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -82,7 +81,6 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, const xla::ComputationDataHandle& logits, const xla::ComputationDataHandle& labels) { const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); const int kBatchDim = 0; const int kClassDim = 1; @@ -100,8 +98,12 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, auto exp_shifted_logits = b->Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) - auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), - add_func, {kClassDim}); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = b->Log(sum_exp); @@ -110,9 +112,13 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - xla::ComputationDataHandle loss = b->Reduce( - b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), - XlaHelpers::Zero(b, type), add_func, {kClassDim}); + auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = b->Mul(b->Neg(labels), sub); + auto sum = + b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index b10880de77e6b9811008076cd4a959c284e558d1..5bb773d97fc5ce90dabceeefd5c29d916597f5ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -239,6 +239,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomUniform") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); @@ -272,6 +273,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 488fda74bf7b5c1d66f8d706a1be3cc1fc29a492..344773c8c5f8e1a552d585d0317c62c56d9f9d46 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -140,17 +140,3 @@ cc_library( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 98f72b3792eb147f5a1847c5e1ecef18bccbca5f..aeb743a6634673f2e8c4dee9ae1e5017944aae2c 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -39,17 +39,3 @@ tf_gen_op_wrapper_py( ":sendrecv_ops", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 1a0e09758f7cc6714793300c6ece14093a8ad246..5759c72af301785f3ca1110b58eeb2fe7dead713 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +66,8 @@ ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !StringPiece(parsed_device.type) - .contains(kDeviceSuffixReplicatedCore)) { + !str_util::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ed10d80609641b090cf78bf2e17364fe2fa89c31..ae51446204baf14dc03fc6305641048dbf3872b0 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0dc5118c9c659cc1529515f34c9eb43fd07a69e8..86263d847ae02d50e70dafb0129b2664c522f2a3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -600,6 +600,48 @@ Status XlaCompiler::BuildArguments( return Status::OK(); } +Status XlaCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& options, string const& name, + OpKernelContext* ctx, const std::vector& args, + CompilationResult* result) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Status status; + // First create the actual node we care about computing. + Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + TF_RETURN_IF_ERROR(status); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + Status status = NodeBuilder(name, "_Arg") + .ControlInput(graph->source_node()) + .Attr("T", ctx->input_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); + } + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64 i = 0; i < ctx->num_outputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + Status status = NodeBuilder(name, "_Retval") + .Input(main_node, i) + .Attr("T", ctx->expected_output_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + } + + return CompileGraph(options, name, std::move(graph), args, result); +} + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -718,8 +760,8 @@ Status XlaCompiler::GetChannelHandle(const string& key, namespace { -void SetTransfer(const string& key, const std::vector& types, - const std::vector& shapes, +void SetTransfer(const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); @@ -733,8 +775,8 @@ void SetTransfer(const string& key, const std::vector& types, } // namespace Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, const std::vector& types, - const std::vector& shapes) { + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -760,8 +802,8 @@ Status XlaCompiler::GetDeviceToHostShapes( } Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, const std::vector& types, - const std::vector& shapes) { + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index a70d2637e0b578ddb57dc990cd9550798e675e1d..a6747bbe72e161b2ece55697825cce0e71145a5c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -289,6 +289,14 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + // Compiles a single Op, given by an OpKernelContext, into an + // xla::Computation. Similar to CompileFunction but takes a single Op as + // input. + Status CompileSingleOp(const CompileOptions& options, string const& name, + OpKernelContext* ctx, + const std::vector& args, + CompilationResult* result); + // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. @@ -304,8 +312,8 @@ class XlaCompiler { // Sets the shapes and types for the device to host transfer associated with // 'key'. Status SetDeviceToHostMetadata(const string& key, - const std::vector& types, - const std::vector& shapes); + gtl::ArraySlice types, + gtl::ArraySlice shapes); // Gets the shapes the device to host transfer associated with 'key'. Status GetDeviceToHostShapes(const string& key, @@ -314,8 +322,8 @@ class XlaCompiler { // Sets the shapes and types for the host to device transfer associated with // 'key'. Status SetHostToDeviceMetadata(const string& key, - const std::vector& types, - const std::vector& shapes); + gtl::ArraySlice types, + gtl::ArraySlice shapes); const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index a18eeacd41808884fac9ec5d617cb0d274ea27d8..096dc7160bfc0a3a751f33e7d646471ebea56070 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -257,10 +258,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - StringPiece(status.error_message()).contains("depends on a parameter")) + str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - StringPiece(status.error_message()).contains("[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) << status.error_message(); } @@ -597,7 +598,8 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); } @@ -676,11 +678,12 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE( - StringPiece(status.error_message()).contains("Attr T is not found")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "Attr T is not found")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f048662953e20b2a612271e2daeef6e370c4822a..3b0b2f06ebae4af918cbe6fb8a384004c1858998 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -273,4 +274,20 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return Status::OK(); } +DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + if (dtype == DT_BFLOAT16) { + return DT_FLOAT; + } + return dtype; +} + +xla::ComputationDataHandle XlaHelpers::ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type) { + xla::PrimitiveType convert_to; + TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); + return builder->ConvertElementType(operand, convert_to); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 2a027db4c839c917f3a7acd27184792d157356bf..68ab93b64a5fa87ad99e0f44d84f6473fc8bbebd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -107,6 +107,18 @@ class XlaHelpers { const xla::ComputationDataHandle& on_value, const xla::ComputationDataHandle& off_value, xla::ComputationDataHandle* one_hot); + + // Certain DataTypes should use increased precision DataTypes when performing + // reductions. This function remaps a given DataType to a higher precision + // DataType if needed. + static DataType SumAccumulationType(const DataType& dtype); + + // A helper for creating a ConvertElementType xla op given a DataType rather + // than the xla::PrimitiveType. + static xla::ComputationDataHandle ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index ff7453194af3a85bded86a5ce298f8779422dccb..e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,13 +51,13 @@ constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index cd13db4d300bb5bba21a734173b6afb9223539d8..751777222fcc7ec073958349aa2677d5b4e6757d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -654,18 +654,6 @@ tf_cc_test( # ----------------------------------------------------------------------------- -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. cc_header_only_library( name = "xla_headers_lib", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 24b58bec11bd8d8b5c79ac84c5f43c509644b51d..ea75ad32d5df7bbadd37e89de6144b264ab6d5d1 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 02356699a25e47be50eb15872df4c9c302fc289b..a299c2afd45aa6b785964b8a8e1400ddf54083a4 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", @@ -213,17 +214,3 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d15ccb0c28522c647617153aaa8e738d029dfaba..3f45167fcb77cd3085c9645fba0b2901329c4bb2 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -177,6 +177,22 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } +StatusOr> Client::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr data, + Execute(computation, arguments, execution_options, execution_profile)); + + const Shape* shape_with_output_layout = nullptr; + if (execution_options && execution_options->has_shape_with_output_layout()) { + shape_with_output_layout = &execution_options->shape_with_output_layout(); + } + return Transfer(*data, shape_with_output_layout); +} + StatusOr Client::LoadSnapshot(const SessionModule& module) { LoadComputationSnapshotRequest request; *request.mutable_module() = module; @@ -231,6 +247,46 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } +StatusOr> Client::Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + ExecuteGraphRequest request; + *request.mutable_computation() = computation.proto(); + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { + *request.mutable_execution_options() = *execution_options; + } + for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; + *request.add_arguments() = argument->handle(); + } + + ExecuteResponse response; + VLOG(1) << "making execute request: " << request.ShortDebugString(); + Status s = stub_->ExecuteGraph(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN( + auto execution_stats, + ExecutionStatsAsString(computation, response.profile())); + VLOG(1) << execution_stats; + } + } + + return MakeUnique(stub_, response.output()); +} + StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteParallelRequest request; @@ -266,6 +322,42 @@ StatusOr>> Client::ExecuteParallel( return std::move(outputs); } +StatusOr>> Client::ExecuteParallel( + tensorflow::gtl::ArraySlice computations) { + ExecuteGraphParallelRequest request; + + for (const XlaComputationInstance& computation : computations) { + ExecuteGraphRequest single_request; + *single_request.mutable_computation() = computation.computation.proto(); + for (GlobalData* argument : computation.arguments) { + *single_request.add_arguments() = argument->handle(); + } + *single_request.mutable_execution_options() = computation.execution_options; + *request.add_requests() = single_request; + } + + ExecuteParallelResponse response; + VLOG(1) << "making execute-graph-parallel request: " + << request.ShortDebugString(); + tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector> outputs; + for (size_t i = 0; i < computations.size(); ++i) { + outputs.push_back( + MakeUnique(stub_, response.responses(i).output())); + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = response.responses(i).profile(); + } + } + + return std::move(outputs); +} + StatusOr> Client::GetDeviceHandles( int64 device_count) { if (device_count < 1) { @@ -342,6 +434,27 @@ StatusOr Client::GetComputationStats( return response.stats(); } +StatusOr Client::GetComputationStats( + const XlaComputation& computation, + const DebugOptions& debug_options) const { + ComputationGraphStatsRequest request; + + // TODO(b/74197823): Find a way to avoid the copy of the hlo proto. + *request.mutable_computation() = computation.proto(); + *request.mutable_debug_options() = debug_options; + ComputationStatsResponse response; + + VLOG(1) << "making computation graph stats request"; + Status s = stub_->GetComputationGraphStats(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + CHECK(response.has_stats()); + return response.stats(); +} + StatusOr> Client::GetComputationShape( const Computation& computation) { GetComputationShapeRequest request; @@ -359,6 +472,12 @@ StatusOr> Client::GetComputationShape( return WrapUnique(response.release_program_shape()); } +StatusOr> Client::GetComputationShape( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); + return MakeUnique(result); +} + StatusOr Client::GetShape(const GlobalData& data) { GetShapeRequest request; *request.mutable_data() = data.handle(); @@ -397,6 +516,28 @@ StatusOr Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } +StatusOr Client::ExecutionStatsAsString( + const XlaComputation& computation, const ExecutionProfile& profile) { + TF_ASSIGN_OR_RETURN( + auto computation_stats, + GetComputationStats(computation, + legacy_flags::GetDebugOptionsFromFlags())); + int64 total_flops = + computation_stats.flop_count() + computation_stats.transcendental_count(); + if (profile.compute_time_ns() > 0) { + int64 nanoseconds = profile.compute_time_ns(); + int64 cycle_count = profile.compute_cycle_count(); + double gflops = total_flops / nanoseconds; + return tensorflow::strings::StrCat( + "[Execution Statistics] flop count: ", computation_stats.flop_count(), + ", transcendental count: ", computation_stats.transcendental_count(), + ", compute execution time: ", nanoseconds, " nsec", + ", compute cycles: ", cycle_count, ", performance: ", gflops, + "gflop/s"); + } + return string("[Execution Statistics] not available."); +} + StatusOr Client::CreateChannelHandle() { CreateChannelHandleRequest request; CreateChannelHandleResponse response; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index c28380b689c7a0e16bf0bcbf15003f4aa15e42a7..05d707dab1533f44ce827157e888720e218d4c9c 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" @@ -57,6 +58,21 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global + // data that was produced from the execution. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // A struct to represent a computation instance to be executed. // * If execution_options.device_handles is not empty, the computation is // executed on the devices associated with the handles by partitioning the @@ -83,6 +99,36 @@ class Client { StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); + // A struct to represent a computation instance to be executed. + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + struct XlaComputationInstance { + const XlaComputation& computation; + std::vector arguments; + ExecutionOptions execution_options; + ExecutionProfile* execution_profile; + + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} + }; + + // Executes a list XlaComputationInstances and returns global data produced + // from each computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr>> ExecuteParallel( + tensorflow::gtl::ArraySlice computations); + // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations // (see ExecuteParallel) or to transfer data (see TransferToServer or @@ -137,6 +183,17 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and transfers the result + // to the client as a literal. Parameters are defined the same as for + // Execute() and Transfer(). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // Unregister the memory for the given GlobalData on the device. Status Unregister(const GlobalData& data); @@ -148,6 +205,13 @@ class Client { StatusOr GetComputationStats( const Computation& computation, const DebugOptions& debug_options) const; + // Retrieves the statistics of the given computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr GetComputationStats( + const XlaComputation& computation, + const DebugOptions& debug_options) const; + // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. StatusOr GetShape(const GlobalData& data); @@ -157,6 +221,13 @@ class Client { StatusOr> GetComputationShape( const Computation& computation); + // As above, but returns the shape of the provided computation (parameter + // types/names and return type). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> GetComputationShape( + const XlaComputation& computation); + // Creates a channel handle that can be used to transfer data between // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); @@ -170,6 +241,8 @@ class Client { // ExecutionProfile returned from an execution of the computation. StatusOr ExecutionStatsAsString(const Computation& computation, const ExecutionProfile& profile); + StatusOr ExecutionStatsAsString(const XlaComputation& computation, + const ExecutionProfile& profile); ServiceInterface* stub_; // Stub that this client is connected on. diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 39d02f0863f78d4094f2cc4805f534713fb7e929..4d3b0ee0d6e9ba82cfa09af0fbff0ae1efa0ac64 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -253,26 +253,6 @@ StatusOr ComputationBuilder::GetProgramShape() { return std::move(*response.mutable_program_shape()); } -ComputationDataHandle ComputationBuilder::CheckShape( - const ComputationDataHandle& operand, const Shape& expected_shape) { - std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); - CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) - << "want " << ShapeUtil::HumanString(expected_shape) << " got " - << ShapeUtil::HumanString(*actual_shape); - return operand; -} - -void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { - std::unique_ptr lhs_shape = GetShape(lhs).ConsumeValueOrDie(); - std::unique_ptr rhs_shape = GetShape(rhs).ConsumeValueOrDie(); - VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " - << ShapeUtil::HumanString(*rhs_shape); - CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) - << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " - << ShapeUtil::HumanString(*rhs_shape); -} - ComputationDataHandle ComputationBuilder::Slice( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice start_indices, diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 2141ebc2065a1a80d2fe820a7b6fe15434c89e28..019c6f3afb5d57bfe453988ded19120a4483cf36 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -104,15 +104,6 @@ class ComputationBuilder { // Retrieves the (inferred) result for the current computation's shape. StatusOr GetProgramShape(); - // Checks that the operand has the given expected shape. Returns the operand - // if yes, fails with a CHECK error if no. - ComputationDataHandle CheckShape(const ComputationDataHandle& operand, - const Shape& expected_shape); - - // Checks that the lhs and rhs results have the same shape. - void CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - // Enqueues a constant with the value of the given literal onto the // computation. ComputationDataHandle ConstantLiteral(const Literal& literal); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 804e34f5e75ce2d153ac7627b94a543fda88e810..6e3c5cb484b8f1ef053fa287a4d462aeb886e530 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -76,4 +76,35 @@ ExecutableBuildOptions::generate_hlo_graph() const { return generate_hlo_graph_; } +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_optimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { + return dump_optimized_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { + return dump_per_pass_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { + hlo_profile_ = enabled; + return *this; +} + +tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { + return hlo_profile_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 3a52dbac9adb155ad9a7d91a8102707f70fe2fbf..11f10983606fe02b1edb11a260edde8e5f9a726f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,15 +58,36 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_generate_hlo_graph(string regex); const tensorflow::gtl::optional& generate_hlo_graph() const; + // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs + // to (as in DebugOptions). + ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + + // If true, specifies that we should record an HLO profile during execution + // and log it after execution (as in DebugOptions). If nullopt the default is + // used. + ExecutableBuildOptions& set_hlo_profile(bool enabled); + tensorflow::gtl::optional hlo_profile() const; + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: + tensorflow::gtl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; + tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index fca2bf2688cd21b44f099da3bae3b890cbb069ab..f4673a8204f27e93441c73f6dcc9130d96cfcebc 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -24,6 +24,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -48,17 +50,3 @@ cc_library( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 24048a1e5a782661ba577ba50e3b5b2914f17c0a..63df449e0b3bdd642d548319dd7d621ca2f59b1d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -26,6 +26,7 @@ limitations under the License. namespace xla { namespace { + using InstructionGenerator = ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, const ComputationDataHandle&); @@ -47,6 +48,27 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type, generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } + +using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); + +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator) { + std::unique_ptr b; + if (type == PRED) { + b = builder->CreateSubBuilder(name); + } else { + b = builder->CreateSubBuilder( + tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + } + + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + generator(b.get(), lhs, rhs); + return b->BuildAndNoteError(); +} + } // namespace Computation CreateScalarAddComputation(PrimitiveType type, @@ -60,7 +82,7 @@ Computation CreateScalarAddComputation(PrimitiveType type, Computation CreateScalarMultiplyComputation(PrimitiveType type, ComputationBuilder* builder) { return CreateScalarComputation( - "add", type, builder, + "mul", type, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); } @@ -114,4 +136,75 @@ StatusOr Any(const ComputationDataHandle& predicates, return builder->Reduce(predicates, f, logical_or, all_dimensions); } +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "add", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Add(lhs, rhs); + }); +} + +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "mul", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Mul(lhs, rhs); + }); +} + +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Ge(lhs, rhs); + }); +} + +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "max", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Max(lhs, rhs); + }); +} + +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "min", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Min(lhs, rhs); + }); +} + +XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "and", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->And(lhs, rhs); + }); +} + +XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "or", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Or(lhs, rhs); + }); +} + +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { + auto f = builder->ConstantR0(false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return builder->Reduce(predicates, f, logical_or, all_dimensions); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index ae89784bc227d837cf15f0a89687dd00dccc2745..f4d3fc801590fedbb84ed3d6283e62f47c56d5c7 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -56,6 +58,48 @@ Computation CreateScalarOrComputation(ComputationBuilder* builder); StatusOr Any(const ComputationDataHandle& predicates, ComputationBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar add computation and returns it. +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar multiply computation and returns it. +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar ge computation and returns it. +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar max computation and returns it. +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar min computation and returns it. +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical AND computation and returns it. +XlaComputation CreateScalarAndComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical OR computation and returns it. +XlaComputation CreateScalarOrComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 91396f055fe4a3ecbd436139be9470e2a35e1c63..30594243dcf51d2b5312b9dcb2bea7d0cd78524d 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -265,6 +265,24 @@ StatusOr> LocalClient::Compile( updated_options)); } +StatusOr> LocalClient::Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options) { + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation, argument_layouts, updated_options)); + return WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); +} + StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index de0ed13c43f87966c272102b2e9af9ff3be63aea..98ee7c62c94be7c618cedd3dc12ecbfc812ee180 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -123,6 +123,15 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Build and return a LocalExecutable object. The executable is compiled using + // the given XlaComputation, argument layouts and options. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index b912889e2627aa01e5a7441e71e6bf002916ba5e..b1dba168565cca86cba0403604736fecd00d6f29 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -25,12 +25,25 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:lib", + ], +) + # TODO(b/74197823): Replace computation_builder with xla_builder. cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], deps = [ + ":xla_computation", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -38,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", @@ -56,22 +70,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 6328a4f350fc70efaa96102f8202fb00b88b51f2..2d587cc3b9c51d5bd81652d17b23d4ad05c84dd3 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include +#include #include #include @@ -43,6 +45,7 @@ int64 GetUniqueId() { bool CanBeRoot(HloOpcode opcode) { switch (opcode) { case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kOutfeed: case HloOpcode::kTrace: return false; @@ -51,25 +54,35 @@ bool CanBeRoot(HloOpcode opcode) { } } -void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) { - instr->set_opcode(HloOpcodeString(opcode)); +StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); + operand_shapes.push_back(shape); + } + return operand_shapes; } } // namespace -StatusOr> XlaBuilder::GetShape(const XlaOp& op) const { +StatusOr XlaBuilder::GetShape(const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); - return MakeUnique(instr->shape()); + return instr->shape(); } StatusOr XlaOp::GetShape() const { - TF_RET_CHECK(builder_ != nullptr); - TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this)); - return *shape; + if (builder_ == nullptr) { + return InvalidArgument( + "cannot GetShape for an invalid XlaOp with handle %lld", handle()); + } + return builder_->GetShape(*this); } XlaBuilder::XlaBuilder(const string& computation_name) - : name_(computation_name) {} + : name_(computation_name), unique_id_(GetUniqueId()) {} XlaBuilder::~XlaBuilder() {} @@ -85,39 +98,47 @@ void XlaBuilder::NoteError(const Status& error) { } } -StatusOr XlaBuilder::Build() { +XlaOp XlaBuilder::NoteErrorOrReturn( + const std::function()>& op_creator) { if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); + return {}; + } + auto op = op_creator(); + if (!op.ok()) { + NoteError(op.status()); + return {}; } + return op.ConsumeValueOrDie(); +} - HloComputationProto entry; - ProgramShape* program_shape = entry.mutable_program_shape(); +StatusOr XlaBuilder::GetProgramShape(int64* root_id) { + TF_RETURN_IF_ERROR(first_error_); - entry.set_name(name_); + TF_RET_CHECK(root_id != nullptr); + ProgramShape program_shape; // Not all instructions can be roots. Walk backwards from the last added // instruction until a valid root is found. - for (int64 i = instructions_.size() - 1; i >= 0; i--) { + int64 index = instructions_.size() - 1; + for (; index >= 0; index--) { TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[i].opcode())); + StringToHloOpcode(instructions_[index].opcode())); if (CanBeRoot(opcode)) { - entry.set_root_name(instructions_[i].name()); - *program_shape->mutable_result() = instructions_[i].shape(); break; } } - if (entry.root_name().empty()) { + if (index < 0) { return FailedPrecondition("no root instruction was found"); } + *root_id = instructions_[index].id(); + *program_shape.mutable_result() = instructions_[index].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. const int64 param_count = parameter_numbers_.size(); for (int64 i = 0; i < param_count; i++) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); + program_shape.add_parameters(); + program_shape.add_parameter_names(); } for (const HloInstructionProto& instr : instructions_) { // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So @@ -127,93 +148,275 @@ StatusOr XlaBuilder::Build() { const int64 index = instr.parameter_number(); TF_RET_CHECK(index >= 0 && index < param_count) << "invalid parameter number: " << index; - *program_shape->mutable_parameters(index) = instr.shape(); - *program_shape->mutable_parameter_names(index) = instr.name(); + *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameter_names(index) = instr.name(); } } + return program_shape; +} + +StatusOr XlaBuilder::GetProgramShape() { + int64 root_id; + return GetProgramShape(&root_id); +} + +XlaComputation XlaBuilder::BuildAndNoteError() { + DCHECK(parent_builder_ != nullptr); + auto build_status = Build(); + if (!build_status.ok()) { + parent_builder_->NoteError( + AddStatus(build_status.status(), + tensorflow::strings::StrCat("error from: ", name_))); + return {}; + } + return build_status.ConsumeValueOrDie(); +} + +StatusOr XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + + { + int64 root_id; + ProgramShape program_shape; + TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id)); + entry.mutable_program_shape()->Swap(&program_shape); + entry.set_root_id(root_id); + } for (auto& instruction : instructions_) { entry.add_instructions()->Swap(&instruction); } - const int64 id = GetUniqueId(); - entry.set_id(id); - XlaComputation computation(id); + entry.set_id(unique_id_); + entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. + XlaComputation computation(entry.id()); HloModuleProto* module = computation.mutable_proto(); module->set_name(entry.name()); + module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); *module->mutable_program_shape() = entry.program_shape(); for (auto& e : embedded_) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); + // Clear data held by this builder. + this->instructions_.clear(); + this->embedded_.clear(); + this->parameter_numbers_.clear(); + return std::move(computation); } -XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto op = [&]() -> StatusOr { +StatusOr XlaBuilder::InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + for (int64 dim : broadcast_dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); +} + +StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + + CHECK(ShapeUtil::IsScalar(operand_shape) || + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = + ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); + + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand_shape)) { + return InDimBroadcast(broadcast_shape, operand, {}); + } + + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand_shape.dimensions(i)); + } else { + TF_RET_CHECK(operand_shape.dimensions(i) == 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand shape: " + << operand_shape << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. + return InDimBroadcast(broadcast_shape, reshaped_operand, + broadcast_dimensions); +} + +XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferUnaryOpShape(unop, operand_shape)); + return AddInstruction(std::move(instr), unop, {operand}); + }); +} + +XlaOp XlaBuilder::BinaryOp( + HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - SetOpcode(&instr, HloOpcode::kAdd); - TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs)); - TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferBinaryOpShape( - HloOpcode::kAdd, lhs_instr->shape(), - rhs_instr->shape(), broadcast_dimensions)); - instr.add_operand_names(lhs_instr->name()); - instr.add_operand_names(rhs_instr->name()); - return AddInstruction(std::move(instr)); - }; - return NoteErrorOrReturn(op()); + binop, lhs_shape, rhs_shape, broadcast_dimensions)); + + const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); + const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + + if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { + const bool should_broadcast_lhs = lhs_rank < rhs_rank; + XlaOp from = should_broadcast_lhs ? lhs : rhs; + const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; + + std::vector to_size; + for (int64 size : instr.shape().dimensions()) { + to_size.push_back(size); + } + for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); + from_dim++) { + int64 to_dim = broadcast_dimensions[from_dim]; + to_size[to_dim] = from_shape.dimensions(from_dim); + } + + const Shape& broadcasted_shape = + ShapeUtil::MakeShape(from_shape.element_type(), to_size); + TF_ASSIGN_OR_RETURN( + XlaOp broadcasted_operand, + InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); + + updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; + updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; + } + + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), updated_lhs)); + } + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), updated_rhs)); + } + + return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + }); +} + +XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferTernaryOpShape( + triop, lhs_shape, rhs_shape, ehs_shape)); + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + XlaOp updated_ehs = ehs; + if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(lhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + // lhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), lhs)); + } + if (!ShapeUtil::IsTuple(rhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + // rhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), rhs)); + } + if (!ShapeUtil::IsTuple(ehs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + // ehs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_ehs, + AddBroadcastSequence(instr.shape(), ehs)); + } + } + return AddInstruction(std::move(instr), triop, + {updated_lhs, updated_rhs, updated_ehs}); + }); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { - HloInstructionProto instr; - SetOpcode(&instr, HloOpcode::kConstant); - *instr.mutable_shape() = literal.shape(); - *instr.mutable_literal() = literal.ToProto(); - return AddInstruction(std::move(instr)); + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kConstant); + }); } XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice operands) { - auto op = [&]() -> StatusOr { + return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - SetOpcode(&instr, HloOpcode::kCall); - std::vector operand_shapes; - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand)); - operand_shapes.push_back(&input->shape()); - } - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferCallShape( - operand_shapes, - /*to_apply=*/computation.GetProgramShape())); + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCallShape(operand_shape_ptrs, + /*to_apply=*/called_program_shape)); - // Add input operands. - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand)); - instr.add_operand_names(operand_instr->name()); - } + AddCalledComputation(computation, &instr); - // Add called computation. - *instr.add_called_computation_names() = computation.proto().name(); - for (const HloComputationProto& e : computation.proto().computations()) { - embedded_.insert({e.id(), e}); - } - - return AddInstruction(std::move(instr)); - }; - return NoteErrorOrReturn(op()); + return AddInstruction(std::move(instr), HloOpcode::kCall, operands); + }); } XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - auto op = [&]() -> StatusOr { + return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - SetOpcode(&instr, HloOpcode::kParameter); if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) { return InvalidArgument("parameter %lld already registered", parameter_number); @@ -222,27 +425,895 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, instr.set_parameter_number(parameter_number); instr.set_name(name); *instr.mutable_shape() = shape; - return AddInstruction(std::move(instr)); - }; - return NoteErrorOrReturn(op()); + return AddInstruction(std::move(instr), HloOpcode::kParameter); + }); +} + +XlaOp XlaBuilder::Broadcast( + const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + const Shape& shape, + ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); + + // The client-level broadcast op just appends dimensions on the left (adds + // lowest numbered dimensions). The HLO broadcast instruction is more + // flexible and can add new dimensions anywhere. The instruction's + // dimensions field maps operand dimensions to dimensions in the broadcast + // output, so to append dimensions on the left the instruction's dimensions + // should just be the n highest dimension numbers of the output shape where + // n is the number of input dimensions. + const int64 operand_rank = ShapeUtil::Rank(operand_shape); + std::vector dimensions(operand_rank); + for (int i = 0; i < operand_rank; ++i) { + dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + } + return InDimBroadcast(shape, operand, dimensions); + }); +} + +StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); +} + +XlaOp XlaBuilder::Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferSliceShape(operand_shape, start_indices, + limit_indices, strides)); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + }); +} + +XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shape, slice_sizes)); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, + {operand, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, start_indices_shape)); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + }); +} + +XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferReshapeShape( + operand_shape, dimensions, new_sizes)); + XlaOp transposed = IsIdentityPermutation(dimensions) + ? operand + : Transpose(operand, dimensions); + return Reshape(shape, transposed); + }); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + std::vector dimensions(shape.dimensions_size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); + }); +} + +XlaOp XlaBuilder::Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return UnimplementedOp(); +} + +void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); +} + +XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false) { + return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); +} + +XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kTuple, operand_shape_ptrs)); + return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); + }); +} + +XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); + if (!ShapeUtil::IsTuple(tuple_shape)) { + return InvalidArgument( + "Operand to GetTupleElement() is not a tuple; got %s", + ShapeUtil::HumanString(tuple_shape).c_str()); + } + *instr.mutable_shape() = + ShapeUtil::GetTupleElementShape(tuple_shape, index); + + instr.set_tuple_index(index); + + return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, + {tuple_data}); + }); +} + +XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape.dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); + }); +} + +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, + dimension_numbers)); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); + }); +} + +XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { + return UnimplementedOp(); +} + +void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, + int64 cost_estimate_ns, const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Complex( + const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); +} + +XlaOp XlaBuilder::Conj(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + +XlaOp XlaBuilder::Not(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNot, operand); +} + +XlaOp XlaBuilder::ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::Abs(const XlaOp& operand) { + return UnaryOp(HloOpcode::kAbs, operand); +} + +XlaOp XlaBuilder::Atan2( + const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); +} + +XlaOp XlaBuilder::Exp(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExp, operand); +} + +XlaOp XlaBuilder::Floor(const XlaOp& operand) { + return UnaryOp(HloOpcode::kFloor, operand); +} + +XlaOp XlaBuilder::Ceil(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCeil, operand); +} + +XlaOp XlaBuilder::Round(const XlaOp& operand) { + return UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} + +XlaOp XlaBuilder::Log(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog, operand); +} + +XlaOp XlaBuilder::Sign(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSign, operand); +} + +XlaOp XlaBuilder::Cos(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCos, operand); +} + +XlaOp XlaBuilder::Sin(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSin, operand); +} + +XlaOp XlaBuilder::Tanh(const XlaOp& operand) { + return UnaryOp(HloOpcode::kTanh, operand); +} + +XlaOp XlaBuilder::Real(const XlaOp& operand) { + return UnaryOp(HloOpcode::kReal, operand); +} + +XlaOp XlaBuilder::Imag(const XlaOp& operand) { + return UnaryOp(HloOpcode::kImag, operand); +} + +XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { + return UnaryOp(HloOpcode::kIsFinite, operand); +} + +XlaOp XlaBuilder::Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferTransposeShape(operand_shape, permutation)); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); + }); +} + +XlaOp XlaBuilder::Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Sort(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSort, operand); +} + +XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + }); +} + +XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Neg(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNegate, operand); +} + +XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, + const XlaOp& max) { + return TernaryOp(HloOpcode::kClamp, min, operand, max); +} + +XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Check the number of parameters per RNG distribution. + switch (distribution) { + case RandomDistribution::RNG_NORMAL: + case RandomDistribution::RNG_UNIFORM: + if (parameters.size() != 2) { + return InvalidArgument( + "RNG distribution (%s) expects 2 parameters, but got %ld", + RandomDistribution_Name(distribution).c_str(), parameters.size()); + } + break; + default: + LOG(FATAL) << "unhandled distribution " << distribution; + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + *instr.mutable_shape() = shape; + + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + }); +} + +XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); +} + +XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); +} + +XlaOp XlaBuilder::While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Infer shape. + TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, + condition.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferWhileShape(condition_program_shape, + body_program_shape, init_shape)); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + }); +} + +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reduce( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape, init_shape, dimensions_to_reduce, + called_program_shape)); + + for (int64 dim : dimensions_to_reduce) { + instr.add_dimensions(dim); + } + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kReduce, + {operand, init_value}); + }); +} + +XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return UnimplementedOp(); } -XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) { +void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Send instruction produces a tuple of {aliased operand, U32 context}. + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN( + XlaOp send, + AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + send_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Recv instruction produces a tuple of {receive buffer, U32 context}. + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, + AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = shape; + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +StatusOr XlaBuilder::IsConstant(const XlaOp& operand, + int64 num_parameters) { + return Unimplemented("IsConstant is not implemented."); +} + +StatusOr> XlaBuilder::ComputeConstant( + const XlaOp& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { + return Unimplemented("ComputeConstant is not implemented"); +} + +std::unique_ptr XlaBuilder::CreateSubBuilder( + const string& computation_name) { + auto sub_builder = MakeUnique(computation_name); + sub_builder->parent_builder_ = this; + sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; + return sub_builder; +} + +Status XlaBuilder::SetReturnValue(const XlaOp& operand) { + return Unimplemented("SetReturnValue is not implemented."); +} + +/* static */ ConvolutionDimensionNumbers +XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_kernel_output_feature_dimension( + kConvKernelOutputDimension); + dimension_numbers.set_kernel_input_feature_dimension( + kConvKernelInputDimension); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(i + 2); + dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); + } + return dimension_numbers; +} + +/* static */ Status XlaBuilder::Validate( + const ConvolutionDimensionNumbers& dnum) { + if (dnum.input_spatial_dimensions_size() < 2) { + return FailedPrecondition("input spacial dimension < 2: %d", + dnum.input_spatial_dimensions_size()); + } + if (dnum.kernel_spatial_dimensions_size() < 2) { + return FailedPrecondition("kernel spacial dimension < 2: %d", + dnum.kernel_spatial_dimensions_size()); + } + if (dnum.output_spatial_dimensions_size() < 2) { + return FailedPrecondition("output spacial dimension < 2: %d", + dnum.output_spatial_dimensions_size()); + } + + if (std::set( + {dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the input are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); + } + if (std::set({dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), + dnum.kernel_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); + } + if (std::set({dnum.output_batch_dimension(), + dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), + dnum.output_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.output_batch_dimension(), dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); + } + return Status::OK(); +} + +StatusOr XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + TF_RETURN_IF_ERROR(first_error_); + const int64 handle = instructions_.size(); + instr.set_id(handle); + instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode(), ".", handle)); + instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle)); } else { // Append the handle to make sure the name is unique. - instr.set_name(StrCat(instr.name(), ".", handle)); + instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle)); + } + for (const auto& operand : operands) { + if (operand.builder_ == nullptr) { + return InvalidArgument("invalid XlaOp with handle %lld", + operand.handle()); + } + if (operand.builder_ != this) { + return InvalidArgument("Do not add XlaOp from builder %s to builder %s", + operand.builder_->name().c_str(), + this->name().c_str()); + } + instr.add_operand_ids(operand.handle()); } + + *instr.mutable_metadata() = metadata_; + if (sharding_) { + *instr.mutable_sharding() = *sharding_; + } + instructions_.push_back(instr); XlaOp op(handle, this); return op; } +void XlaBuilder::AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr) { + instr->add_called_computation_ids(computation.proto().entry_computation_id()); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } +} + StatusOr XlaBuilder::LookUpInstruction( const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + if (op.builder_ != this) { + return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + } + TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); @@ -250,4 +1321,9 @@ StatusOr XlaBuilder::LookUpInstruction( return &instructions_[op.handle()]; } +XlaOp XlaBuilder::UnimplementedOp() { + NoteError(Unimplemented("Op not implemented")); + return {}; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 7632bd289d792ef487fb667de3cea335e06778bf..0673b86646eeecae45b1076baf0002ed94242846 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -24,8 +24,11 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -49,10 +52,11 @@ class XlaBuilder; // TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: + XlaOp() : handle_(0), builder_(nullptr) {} + StatusOr GetShape() const; private: - XlaOp() : handle_(0), builder_(nullptr) {} XlaOp(int64 handle, XlaBuilder* builder) : handle_(handle), builder_(builder) {} @@ -63,38 +67,6 @@ class XlaOp { XlaBuilder* builder_; // Not owned. }; -// The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. -class XlaComputation { - public: - XlaComputation(const XlaComputation&) = delete; - XlaComputation& operator=(const XlaComputation&) = delete; - - XlaComputation(XlaComputation&& from) { *this = std::move(from); } - - XlaComputation& operator=(XlaComputation&& from) { - proto_ = std::move(from.proto()); - unique_id_ = from.unique_id_; - return *this; - } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - const ProgramShape& GetProgramShape() const { return proto_.program_shape(); } - - const HloModuleProto& proto() const { return proto_; } - - private: - // Creates a null Computation. - XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} - HloModuleProto* mutable_proto() { return &proto_; } - friend class XlaBuilder; - - int64 unique_id_; - HloModuleProto proto_; -}; - // A convenient interface for building up computations. // // Thread-compatible. @@ -113,6 +85,29 @@ class XlaBuilder { // Returns the computation name. const string& name() const { return name_; } + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is // called (which is the default). @@ -120,14 +115,6 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } - // Enqueues an add instruction onto the computation. - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a call instruction onto the computation. - XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); - // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -155,16 +142,669 @@ class XlaBuilder { // corresponding native type yet. template XlaOp ConstantR0(NativeT value); + template + XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template + XlaOp ConstantR2( + std::initializer_list> values); + template + XlaOp ConstantFromArrayWithLayout(const Array& values, + const Layout& layout); + template + XlaOp ConstantFromArray(const Array& values); + template + XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + XlaOp ConstantR2FromArray2D(const Array2D& values); + template + XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + XlaOp ConstantR3FromArray3D(const Array3D& values); + template + XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + template + XlaOp ConstantR4FromArray4D(const Array4D& values); - // Returns the shape of the given op. - StatusOr> GetShape(const XlaOp& op) const; + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. + XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value across all replicas. All replicas + // supply one input to the sum and all replicas receive the resulting sum. + XlaOp CrossReplicaSum(const XlaOp& operand); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. + XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + XlaOp SqrtF32(const XlaOp& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + XlaOp SquareF32(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. + XlaOp ReciprocalF32(const XlaOp& operand); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + XlaOp Sort(const XlaOp& operand); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on parameters with index greater than or equal to + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const XlaOp& operand, int64 num_parameters = 0); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + + // Computes the value of a constant indicated by a XlaOp using a non-optimized + // interpreter on the host. + // + // The operand must represent a constant value, which in this case + // means that it must not statically depend on any parameter of the + // computation that is being built other then the ones specified on the + // parameter list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. + // + // `IsConstant` can be used to test whether a computation is a compile-time + // constant without evaluation it. `ComputeConstant` only succeeds for + // computations where `IsConstant` returns true. + // + // This functionality can be useful when translating a computation + // into XLA where something that looked dynamic is required by + // XLA to be specified as a constant. E.g. the source + // computation (outside of XLA) may include a dynamic + // computation of the shape of something and ComputeConstant lets + // you determine what the value of that computation is in the case + // where the value can be determined at compile time. + // + // If output_layout is non-null, then the output of the computation + // will be stored using that layout. + StatusOr> ComputeConstant( + const XlaOp& operand, const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Modifies the computation being built so that executions of it will return + // the value associated with operand, rather than the last expression enqueued + // on the XlaBuilder. Any subsequent operations added to the XlaBuilder will + // not have any effect unless SetReturnValue is called again. + Status SetReturnValue(const XlaOp& operand); // Builds the computation with the requested operations, or returns a non-ok // status. StatusOr Build(); + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + private: - XlaOp AddInstruction(HloInstructionProto&& instr); + StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands = {}); + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); // Notes that the error occurred by: // * storing it internally and capturing a backtrace if it's the first error @@ -172,17 +812,49 @@ class XlaBuilder { // * dying if die_immediately_on_error_ is true void NoteError(const Status& error); - XlaOp NoteErrorOrReturn(StatusOr&& op) { - if (!op.ok()) { - NoteError(op.status()); - return XlaOp(); - } - return op.ConsumeValueOrDie(); - } + XlaOp NoteErrorOrReturn(const std::function()>& op_creator); + + // Helper method that creates an empty op and notes error. + XlaOp UnimplementedOp(); StatusOr LookUpInstruction(const XlaOp& op) const; - string name_; // Name to use for the built computation. + // Internal helper method that does the building for an arbitrary unary op. + XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs); + + XlaOp RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape); + + StatusOr InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + StatusOr Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape for the current + // computation and fills the root_id in the pointer. + StatusOr GetProgramShape(int64* root_id); + + string name_; // Name to use for the built computation. + int64 unique_id_; // The unique id for the built computation. // The first error encountered while building the computation. // This is OK until the first error is encountered. @@ -202,8 +874,19 @@ class XlaBuilder { // The unique parameter numbers. tensorflow::gtl::FlatSet parameter_numbers_; + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; + // Mode bit that indicates whether to die when a first error is encountered. bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; }; template @@ -211,6 +894,76 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) { return ConstantLiteral(*Literal::CreateR0(value)); } +template +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantLiteral(*Literal::CreateR2(values)); +} + +template +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, + const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantFromArray(const Array& values) { + return ConstantLiteral(*Literal::CreateFromArray(values)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { + return ConstantFromArray(values); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { + return ConstantFromArray(values); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index a400e4e78b044ae633a0135b0011d5267eacc115..ce984564d016ce65fa6c932f3cda290cc0d75a4a 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -39,7 +40,8 @@ class XlaBuilderTest : public ::testing::Test { TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, - HloModule::CreateModuleConfigFromProto(proto)); + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -57,16 +59,16 @@ TEST_F(XlaBuilderTest, OnePlusTwo) { EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); } -TEST_F(XlaBuilderTest, ParamPlusConstant) { +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); b.Add(x, b.ConstantR0(1.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant())); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); } -TEST_F(XlaBuilderTest, ParamPlusParam) { +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); @@ -79,7 +81,7 @@ TEST_F(XlaBuilderTest, ParamPlusParam) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1))); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); } TEST_F(XlaBuilderTest, XPlusX) { @@ -133,5 +135,103 @@ TEST_F(XlaBuilderTest, Call) { op::Call(op::Constant(), op::Constant()))); } +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + b.Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. + // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + builder.Add(p0, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Do not add XlaOp from builder b1 to builder main")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +TEST_F(XlaBuilderTest, Xor) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Xor(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + LOG(ERROR) << module->ToString(); + EXPECT_THAT(root, + op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), + op::And(op::Parameter(0), op::Not(op::Parameter(1))))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6752c601026518825c7994f6b6fa20d20f34f24 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" + +#include + +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +StatusOr XlaComputation::GetProgramShape() const { + TF_RET_CHECK(proto_.has_program_shape()); + return proto_.program_shape(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h new file mode 100644 index 0000000000000000000000000000000000000000..2a3c6952667a434b68ca0c5e4e9874397da173d3 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +// +// TODO(b/74197823): Replace xla::Computation with this one. +class XlaComputation { + public: + XlaComputation() : unique_id_(-1) {} + + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + StatusOr GetProgramShape() const; + + const HloModuleProto& proto() const { return proto_; } + + private: + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 0a9725db0a4fcf963cadcacf2cbc1d95d2c7239d..89353448e29ec3d97275dac288e23aa8e96e31b2 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -75,17 +75,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index a3b4286f4c12bf39a44c63dd6e7d303a46a418c3..7b6ae311c1099dccb8dceb2f49743c1b185cd5ab 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 0a24db046a390eb447bc3518476c3dd9897d973c..13675b7d0074592043b7e12de0aad948a3e9848f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -929,7 +929,7 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, case U64: return StrCat(Get(multi_index, shape_index)); case F16: - return StrCat(Get(multi_index, shape_index)); + return StrCat(static_cast(Get(multi_index, shape_index))); case F32: return StrCat(Get(multi_index, shape_index)); case BF16: @@ -979,7 +979,8 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, return StrCat( GetSparseElement(sparse_element_number, shape_index)); case F16: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); case F32: return StrCat( GetSparseElement(sparse_element_number, shape_index)); @@ -1384,8 +1385,9 @@ void Literal::EachCellAsString( } namespace { -template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const Literal& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1395,11 +1397,18 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); + dest_data[i] = converter(src_data[i]); } return result_literal; } +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + template std::unique_ptr ConvertToC64(const Literal& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); @@ -1462,6 +1471,9 @@ StatusOr> ConvertIfDestTypeMatches( StatusOr> Literal::Convert( PrimitiveType primitive_dest_type) const { TF_RET_CHECK(ShapeUtil::IsArray(shape())); + if (shape().element_type() == primitive_dest_type) { + return CloneToUnique(); + } switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ @@ -1488,8 +1500,16 @@ StatusOr> Literal::Convert( } StatusOr> Literal::ConvertToShape( - const Shape& dest_shape) const { + const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } return Convert(dest_shape.element_type()); } std::vector elements; diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e24f5285d9a14cf26216e4a16c6d1e516afc413f..a96a76fbb4e1a46e225d33b715f073c05fe6275a 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -340,8 +340,14 @@ class Literal { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. StatusOr> ConvertToShape( - const Shape& dest_shape) const; + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 04e45f00491b0bef94f3c0af1c875b2d007194fd..7627762074b6132655c58690a7fffbaf2717e279 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -1702,7 +1702,7 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(half{2.0})); + tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( Literal::CreateSparse( dimensions, indices, diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index e2972f06016ab3555c4fc0cc4616993fe6764b1e..0517a5502e686def4ffea59f929aef225186a8aa 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -72,15 +72,3 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla/service:cpu_plugin", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b21ab3044fae7136071f50bdba6e74b799a309d5..2bacc6a9142971f6d14b3929fb1a69e2a40052e2 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -521,6 +521,17 @@ ComputationDataHandle LocalComputationBuilder::Conditional( false_computation.computation()); } +StatusOr LocalComputationBuilder::IsConstant( + const ComputationDataHandle& operand, int64 num_parameters) { + return builder_.IsConstant(operand, num_parameters); +} + +StatusOr> LocalComputationBuilder::ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { + return builder_.ComputeConstant(operand, output_layout, parameters); +} + #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ return builder_.method_name args; \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a7375c8965e9041226ffee08dab6ffafa25312af..31046e60f11af9cc89ddec4c5fd16babfc8eb231 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -268,6 +268,13 @@ class LocalComputationBuilder { const ComputationDataHandle& false_operand, const LocalComputation& false_computation); + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters); + + StatusOr > ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters); + #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b5354131c94930b75ea66036ddb61ecd3993414f..ac792e8189bda9eda472e7d282db86ac988c57b9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -141,6 +141,33 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) { return true; } +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, + const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } + if (!PyString_Check(attr)) { + string message = tensorflow::strings::Printf("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr).c_str()); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + } } %} @@ -155,7 +182,7 @@ tensorflow::ImportNumpy(); %typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { const int64 handle = numpy::PyIntOrPyLongToLong($input); if (handle == -1 && PyErr_Occurred()) { - return NULL; + SWIG_fail; } temp.set_handle(handle); $1 = &temp; @@ -174,7 +201,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -184,7 +211,7 @@ tensorflow::ImportNumpy(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -197,7 +224,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -206,7 +233,16 @@ tensorflow::ImportNumpy(); $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; } } @@ -214,8 +250,9 @@ tensorflow::ImportNumpy(); if (!$1.ok()) { PyErr_SetString( PyExc_RuntimeError, $1.ToString().c_str()); - return NULL; + SWIG_fail; } + Py_INCREF(Py_None); $result = Py_None; } @@ -225,7 +262,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -237,13 +274,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "Argument sequence element cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i] = numpy::PyIntOrPyLongToLong(py_int); if (temps[i] == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } Py_DECREF(py_int); Py_DECREF(o); @@ -257,7 +294,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -268,13 +305,13 @@ tensorflow::ImportNumpy(); PyErr_SetString( PyExc_TypeError, "Argument sequence element cannot be converted to int"); - return NULL; + SWIG_fail; } const int64 handle = numpy::PyIntOrPyLongToLong(py_int); if (handle == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i].set_handle(handle); Py_DECREF(py_int); @@ -289,7 +326,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); @@ -298,7 +335,7 @@ tensorflow::ImportNumpy(); LocalShapedBuffer* lsbp; if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION)) == -1) { - return NULL; + SWIG_fail; } temps.push_back(lsbp); Py_DECREF(o); @@ -312,7 +349,7 @@ tensorflow::ImportNumpy(); literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - return NULL; + SWIG_fail; } $1 = literal_status.ValueOrDie().get(); } @@ -324,7 +361,7 @@ tensorflow::ImportNumpy(); %typemap(out) StatusOr< std::unique_ptr > { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); } @@ -332,7 +369,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -341,7 +378,7 @@ tensorflow::ImportNumpy(); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); Py_DECREF(o); @@ -355,7 +392,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::OpMetadataFromPyObject($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -367,7 +404,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -382,7 +419,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -396,7 +433,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -405,7 +442,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -416,7 +453,7 @@ tensorflow::ImportNumpy(); std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -428,7 +465,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -442,18 +479,18 @@ tensorflow::ImportNumpy(); PyObject* py_int = numpy::PyNumberToPyInt($input); if (!py_int) { PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - return NULL; + SWIG_fail; } const long value = numpy::PyIntOrPyLongToLong(py_int); if (value == -1 && PyErr_Occurred()) { Py_DECREF(py_int); - return NULL; + SWIG_fail; } if (!PrimitiveType_IsValid(value)) { PyErr_SetString( PyExc_TypeError, "Argument not valid for PrimitiveType enum"); Py_DECREF(py_int); - return NULL; + SWIG_fail; } $1 = static_cast(value); } @@ -464,19 +501,19 @@ tensorflow::ImportNumpy(); (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (!o) { - return NULL; + SWIG_fail; } PyObject* first = PyTuple_GetItem(o, 0); if (!first) { Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* first_pyint = numpy::PyNumberToPyInt(first); if (!first_pyint) { @@ -484,13 +521,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "First pair item cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* second = PyTuple_GetItem(o, 1); if (!second) { Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } PyObject* second_pyint = numpy::PyNumberToPyInt(second); if (!second_pyint) { @@ -499,21 +536,21 @@ tensorflow::ImportNumpy(); "Second pair item cannot be converted to int"); Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); if (first_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); if (second_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } temps.push_back(std::make_pair(first_value, second_value)); Py_DECREF(o); @@ -531,26 +568,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( $input, "lhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_contracting_dimensions); if (length == -1) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); if (!item) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -561,26 +598,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( $input, "rhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_contracting_dimensions); if (length == -1) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); if (!item) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -591,26 +628,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_batch_dimensions = PyObject_GetAttrString( $input, "lhs_batch_dimensions"); if (!lhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_batch_dimensions); if (length == -1) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); if (!item) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_batch_dimensions(dimension); Py_DECREF(item); @@ -621,26 +658,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_batch_dimensions = PyObject_GetAttrString( $input, "rhs_batch_dimensions"); if (!rhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_batch_dimensions); if (length == -1) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); if (!item) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_batch_dimensions(dimension); Py_DECREF(item); @@ -656,20 +693,20 @@ tensorflow::ImportNumpy(); (PaddingConfig padding_config) { PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); if (!dimensions) { - return NULL; + SWIG_fail; } int length = PySequence_Size(dimensions); if (length == -1) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(dimensions, i); if (!item) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } int64 edge_padding_low, edge_padding_high, interior_padding; if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) @@ -677,7 +714,7 @@ tensorflow::ImportNumpy(); || !GetIntAttr(item, "interior_padding", &interior_padding)) { Py_DECREF(item); Py_DECREF(dimensions); - return NULL; + SWIG_fail; } Py_DECREF(item); @@ -699,32 +736,32 @@ tensorflow::ImportNumpy(); int64 value; if (!GetIntAttr($input, "input_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_batch_dimension(value); if (!GetIntAttr($input, "input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_feature_dimension(value); if (!GetIntAttr($input, "output_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_batch_dimension(value); if (!GetIntAttr($input, "output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_input_feature_dimension(value); @@ -733,24 +770,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "input_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_input_spatial_dimensions(dimension); Py_DECREF(item); @@ -759,24 +796,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_kernel_spatial_dimensions(dimension); Py_DECREF(item); @@ -785,24 +822,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "output_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_output_spatial_dimensions(dimension); Py_DECREF(item); @@ -819,16 +856,32 @@ tensorflow::ImportNumpy(); if ($input == Py_None) { $1 = NULL; } else { - PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); - if (!o) { - return NULL; + if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { + build_options.set_generate_hlo_graph(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { + build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { + build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + + PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); + if (o == NULL) { + SWIG_fail; } if (o != Py_None) { - if (!PyString_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); - return NULL; + if (!PyBool_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); + SWIG_fail; } - build_options.set_generate_hlo_graph(PyString_AsString(o)); + build_options.set_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -841,7 +894,7 @@ tensorflow::ImportNumpy(); if (!statusor.ok()) { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } build_options.set_result_layout(statusor.ValueOrDie()); } @@ -907,6 +960,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; %unignore xla::swig::LocalComputationBuilder::Conditional; +%unignore xla::swig::LocalComputationBuilder::IsConstant; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 3d87480728aab1d4ebbc71c6c7504d37cae5edaf..eec48479c929ab0823fef342fc284bfdc4b1f339 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -170,8 +170,7 @@ static string PyObjectCppStr(PyObject* o) { return ExtractStringAndDecref(s); } -// Safely returns a repr of the given Python object o as a C++ string. -static string PyObjectCppRepr(PyObject* o) { +string PyObjectCppRepr(PyObject* o) { PyObject* r = PyObject_Repr(o); return ExtractStringAndDecref(r); } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index adfcc3b8588dce01718bb19dea936bace483be4d..9656cb1c31c39dbe54293700c2765d0723255657 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -107,6 +107,9 @@ void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { std::copy(source.begin(), source.end(), dest); } +// Safely returns a repr of the given Python object o as a C++ string. +string PyObjectCppRepr(PyObject* o); + // Workarounds for Python 2 and 3 interop PyObject* LongToPyIntOrPyLong(long x); // NOLINT diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 90cda42f3227c80826ffbf4e5473647c2795544d..9c81f6439d0d9f0a0f0d1d3402e9c1ada46e8691 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -320,6 +320,9 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None + self.dump_optimized_hlo_proto_to = None + self.dump_per_pass_hlo_proto_to = None + self.hlo_profile = False def transfer_to_infeed(value, replica_number=None): @@ -1025,6 +1028,20 @@ class ComputationBuilder(object): _unwrap_data_handle(false_operand), false_computation.c_local_computation)) + def IsConstant(self, operand, num_parameters=0): + """Enqueues an IsConstant operation onto the computation. + + Args: + operand: a ComputationDataHandle to test. + num_parameters: optional int, number of computation parameters to treat as + constant (default 0). + + Returns: bool indicating whether `operand` is a compile-time constant, + meaning its value does not depend on parameters with index greater than or + equal to `num_parameters`. + """ + return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 4c16c1f8b07a28d8098e92e27f81a126ed9bdf0c..d97264ea640787ab865f3cd64867addedd73cc1d 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -855,6 +855,17 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testIsConstant(self): + c = self._NewComputation() + a = c.ConstantS32Scalar(3) + b = c.ConstantS32Scalar(1) + x = c.ParameterFromNumpy(NumpyArrayS32(0)) + const_expr = c.Sub(b, a) + non_const_expr = c.Mul(const_expr, x) + self.assertTrue(c.IsConstant(const_expr)) + self.assertFalse(c.IsConstant(non_const_expr)) + # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index fba20c94cafea587bffcd766d1122d6327f32182..3a99d84bea63636870609a01c10f2bb3e0e5e8d7 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -285,6 +285,23 @@ cc_library( ], ) +tf_cc_test( + name = "dfs_hlo_visitor_with_default_test", + srcs = ["dfs_hlo_visitor_with_default_test.cc"], + deps = [ + ":hlo", + ":hlo_runner", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], @@ -623,6 +640,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -712,7 +730,6 @@ cc_library( ":computation_layout", ":device_memory_allocator", ":hlo", - ":hlo_cost_analysis", ":hlo_execution_profile", ":hlo_graph_dumper", ":pool", @@ -1129,6 +1146,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1275,6 +1293,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], + deps = [ + ":gather_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:test_macros_header", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "conditional_simplifier", srcs = ["conditional_simplifier.cc"], @@ -1566,6 +1596,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -2619,17 +2650,3 @@ cc_library( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index be7aa307d2c9f70ba8d334b842a4ff29a49687f9..0e4624fd69e623efca780937c5347dbf6bb9afe1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -302,7 +302,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplication on platforms where it causes a slowdown. + // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; }; @@ -385,7 +385,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { auto* c2 = rhs; TF_ASSIGN_OR_RETURN(auto* sum_of_constants, - CreateBinaryHlo(HloOpcode::kAdd, c1, c2)); + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), @@ -636,16 +636,14 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (lhs->opcode() == HloOpcode::kDivide && rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - auto a_times_d, - CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN( - auto b_times_c, - CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), - rhs->mutable_operand(0))); - TF_ASSIGN_OR_RETURN(auto new_divide, CreateBinaryHlo(HloOpcode::kDivide, - a_times_d, b_times_c)); + TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply, + lhs->mutable_operand(0), + rhs->mutable_operand(1))); + TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, + lhs->mutable_operand(1), + rhs->mutable_operand(0))); + TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, + a_times_d, b_times_c)); return ReplaceInstruction(divide, new_divide); } @@ -654,7 +652,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { if (lhs->opcode() == HloOpcode::kDivide) { TF_ASSIGN_OR_RETURN( auto b_times_c, - CreateBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, @@ -663,9 +661,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // A / (B / C) => (A*C) / B if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - auto a_times_c, - CreateBinaryHlo(HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs, + rhs->mutable_operand(1))); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, @@ -1124,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { auto operand = broadcast->mutable_operand(0); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. - if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1145,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1160,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1204,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. + if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector new_dimensions; + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } @@ -1300,8 +1309,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, - CreatePadHlo(pad->mutable_operand(0), - pad->mutable_operand(1), nonzero_padding)); + MakePadHlo(pad->mutable_operand(0), + pad->mutable_operand(1), nonzero_padding)); // Copy the layout from the original pad instructions. The new pad and the // slice instruction should all have the same layout. TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -1329,7 +1338,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { TF_ASSIGN_OR_RETURN( HloInstruction * slice, - CreateSliceHlo(nonzero_pad, start_indices, end_indices, strides)); + MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); // Verify that the slice shape matches the pad shape. TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); @@ -1722,18 +1731,29 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } - VLOG(10) << "Considering folding Pad: " << operand->ToString() - << "\ninto reduce-window: " << reduce_window->ToString(); - // This optimization folds a pad op into reduce_window. - if (operand->opcode() != HloOpcode::kPad) { + HloInstruction* pad; + const HloInstruction* convert = nullptr; + if (operand->opcode() == HloOpcode::kPad) { + pad = operand; + } else if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->opcode() == HloOpcode::kPad) { + convert = operand; + pad = operand->mutable_operand(0); + } else { VLOG(10) << "Not folding pad into reduce-window as there is no pad."; return Status::OK(); } + VLOG(10) << "Considering folding Pad: " << pad->ToString() + << "\ninto reduce-window: " << reduce_window->ToString() + << (convert != nullptr ? tensorflow::strings::StrCat( + "\nvia convert: ", convert->ToString()) + : ""); + // Do not fold interior padding into ReduceWindow since the backends do not // support it. - const PaddingConfig& pad_config = operand->padding_config(); + const PaddingConfig& pad_config = pad->padding_config(); if (HasInteriorPadding(pad_config)) { VLOG(10) << "Not folding pad into reduce-window due to interior padding."; return Status::OK(); @@ -1741,14 +1761,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If reduce_window already has padding, the pad value of the pad op and the // init value of reduce_window must match to allow folding the pad. - const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* pad_value = pad->operand(1); const HloInstruction* reduce_init_value = reduce_window->operand(1); if (pad_value != reduce_init_value) { + auto literals_are_equivalent = [&] { + auto& pad_literal = pad_value->literal(); + auto& reduce_init_literal = reduce_init_value->literal(); + if (pad_literal == reduce_init_literal) { + return true; + } + auto converted_pad_literal = pad_literal.ConvertToShape( + reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + if (!converted_pad_literal.ok()) { + return false; + } + return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - pad_value->literal() != reduce_init_value->literal()) { + !literals_are_equivalent()) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1757,7 +1790,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If the pad puts a single non-identity value in each window that we're // reducing, then this is a broadcast. - HloInstruction* pad_operand = operand->mutable_operand(0); + HloInstruction* pad_operand = pad->mutable_operand(0); auto is_effective_broadcast = [&] { if (window_util::HasStride(window)) { VLOG(10) << "Window has stride."; @@ -1801,6 +1834,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Found window covers a single unpadded element."; return true; }; + + HloInstruction* new_reduce_window_operand; + if (convert != nullptr) { + new_reduce_window_operand = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad_operand->shape(), + convert->shape().element_type()), + pad_operand)); + } else { + new_reduce_window_operand = pad_operand; + } + if (is_effective_broadcast()) { VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; auto fadd = [this](std::unique_ptr x) { @@ -1809,7 +1854,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcastSequence( /*output_shape=*/reduce_window->shape(), - /*operand=*/pad_operand, fadd)); + /*operand=*/new_reduce_window_operand, fadd)); } // Carry out the folding of the pad into reduce_window. @@ -1826,10 +1871,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( window_dim.set_padding_high(window_dim.padding_high() + pad_dim.edge_padding_high()); } + return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/pad_operand, + /*operand=*/new_reduce_window_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 43315f5cdc7afbe79039420320f4a0d0535e11f1..c48196e861a559a5abfa360841ec70b39356fa2b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs AlgebraicSimplications. +// A pass which performs algebraic simplifications. class AlgebraicSimplifier : public HloPassInterface { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to @@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface { bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; - // Enable dot simplication on platforms where it is profitable. + // Enable dot simplification on platforms where it is profitable. bool enable_dot_strength_reduction_; - // Enable convolution simplication on platforms where it is profitable. + // Enable convolution simplification on platforms where it is profitable. bool enable_conv_simplification_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 451294ef5d8367686d7fc22b7f5ebfde89d14d42..20c549562d5153c802c1e675a8ff1c92426b8832 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" +using ::testing::ElementsAre; + namespace xla { namespace { @@ -2336,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); } +// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to +// ReduceWindow(Convert(op), x). +TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. + HloInstruction* parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); + + HloInstruction* convert = + builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad->shape(), F32), pad)); + + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module.AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, convert, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + // Running simplification again should not result in any further changes. + ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); @@ -2462,6 +2549,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* input_array = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, input_array, {1})); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root->dimensions(), ElementsAre(2)); +} + +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + // The initial dimensions go to places 0 and 2 in the 3-dim array, + // and to places 1 and 3 in the 4-dim array, + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, param0, {0, 2})); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 432448e9bbc7db30ed67a0130d52b060032362d5..08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,6 +34,9 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); @@ -84,6 +87,25 @@ Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( return Status::OK(); } +namespace { + +// Returns whether hlo has users and all users are conversions from F32 to BF16. +bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) { + if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) { + return false; + } + for (const auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + return false; + } + return true; +} + +} // namespace + Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( HloInstruction* hlo) { std::vector bf16_to_f32_operands; @@ -104,22 +126,9 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( } } - bool fold_output_conversion = hlo->user_count() > 0 && - hlo->shape().element_type() == F32 && - bfloat16_support_->SupportsBF16Output(*hlo) && - hlo != computation_->root_instruction(); - if (fold_output_conversion) { - for (auto user : hlo->users()) { - if (user->opcode() == HloOpcode::kConvert && - user->shape().element_type() == BF16) { - continue; - } - // We should not change the output type if any user is not a conversion - // from F32 to BF16. - fold_output_conversion = false; - break; - } - } + const bool fold_output_conversion = + AllUsersAreF32ToBF16Converts(hlo) && + bfloat16_support_->SupportsBF16Output(*hlo); if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { if (has_other_f32_operands || @@ -171,6 +180,52 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } +Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape()) || + !bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return DefaultAction(crs); + } + + // First use DefaultAction() to handle the operands. It can't handle + // tuple-shaped output. + TF_RETURN_IF_ERROR(DefaultAction(crs)); + + // Then do per-tuple-element handling on the output. + std::vector> per_tuple_element_gtes( + crs->operand_count()); + for (auto user : crs->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return Status::OK(); + } + per_tuple_element_gtes[user->tuple_index()].push_back(user); + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + // Fold conversions only when all the get-tuple-elements' users are + // conversions from F32 to BF16. + auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { + for (auto gte : per_tuple_element_gtes[i]) { + if (!AllUsersAreF32ToBF16Converts(gte)) { + return false; + } + } + return true; + }; + if (!all_gte_users_are_bf16_convert()) { + continue; + } + + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) + ->set_element_type(BF16); + for (auto gte : per_tuple_element_gtes[i]) { + TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); + } + } + + return Status::OK(); +} + StatusOr BFloat16ConversionFolding::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index cb37759439debf41a305ec7dccaa548e1bf234cd..28e71c2054f59ba4d5d096bf7d898161877bb42f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -37,7 +37,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -47,7 +48,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -55,7 +57,8 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -206,4 +209,46 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } +TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* convert_a = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, a)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + HloInstruction* gte_a = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); + HloInstruction* gte_b = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 1)); + HloInstruction* convert_gte_b = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte_b)); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({gte_a, convert_gte_b})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_EQ(tuple->operand(0), gte_a); + EXPECT_EQ(tuple->operand(1), gte_b); + EXPECT_EQ(gte_a->shape().element_type(), F32); + EXPECT_EQ(gte_b->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0), a); + EXPECT_EQ(crs->operand(1), b); + EXPECT_EQ(a->shape().element_type(), BF16); + EXPECT_EQ(b->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {0}).element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 531f36e8c5473ef684e654ed6b89c4d5ef04b401..c26d2feef584faeff013a602409cdd58c2d44a5a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -606,8 +606,10 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN(auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape())); + TF_ASSIGN_OR_RETURN( + auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape(), + /*round_f32_to_bf16=*/true)); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); @@ -627,6 +629,27 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( return Status::OK(); } +Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { + for (auto computation : module->computations()) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConvert) { + continue; + } + auto source = hlo->mutable_operand(0); + if (!ShapeUtil::Equal(source->shape(), hlo->shape())) { + continue; + } + const bool is_root = hlo == computation->root_instruction(); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source)); + if (is_root) { + computation->set_root_instruction(source); + } + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); + } + } + return Status::OK(); +} + // The algorithm first does a forward pass (parameters to root) to determine a // set of instructions to consider using bfloat16, then does a backward pass to // determine the precisions of those instructions according to the need of @@ -677,6 +700,10 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // defining instruction's shape has changed. So we need to adjust the output // shapes of instructions according to the HLO values they refer to. TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + + // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> + // BF16), so we remove them now. + TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); return true; } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 89a5ac5db1549877a135182ae8df57fa6bf9d579..1744e9db90aeff269daa91eb68a1d61bb0fc3035 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -133,6 +133,11 @@ class BFloat16Propagation : public HloPassInterface { // by the given HLO. void AdjustCalledComputationRoot(HloInstruction* hlo); + // *************************** + // Removes no-op conversions (same source and target shapes) that can be + // produced this pass. + Status RemoveNoopConversions(HloModule* module); + // *************************** // Functions called and state used by two or more passes. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5950b004b3da439c442eec6e5e09ea2307fcb018..88f83014164ff726a11e45e762b9c082cf12720d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -617,4 +617,44 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { EXPECT_EQ(computation->root_instruction(), dot); } +// Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 -> +// BF16 conversion), then it will remove that conversion. +TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "param")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1)); + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape, HloOpcode::kAdd, convert0, convert1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), add2); + EXPECT_EQ(add2->operand(0), gte0); + EXPECT_EQ(add2->operand(1), gte1); + EXPECT_EQ(gte0->shape().element_type(), BF16); + EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 6664496ab6c603c35c7dce923fcf94c54d1ce714..c83da9eddc8f8b156dd9acfc99b393bf844575da 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -100,7 +100,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, *user_computation)); + &execution_options, user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 33e19efc72c6d30ccd7e0b3a13f664a4f42208bf..b4b53ae2ed425a48de5bcb6ba5c37b5d37e1f371 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -127,7 +127,7 @@ class Compiler { // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses - // prior to calling this method because the some HLO passes are required for + // prior to calling this method because some HLO passes are required for // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index e9c974a0461da4b79b4d4cf7a15f407ead5eb4bb..40519ecc799c8f0343294ad88009820dbd8535e9 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -78,8 +78,9 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, policy.copy_root_replicated_buffers = true; } for (const CallSite& site : node.caller_callsites()) { - // The kWhile instruction does not have an handling here, as the - // AddCopiesForWhile() API takes care of adding its own copies. + // The AddCopiesForConditional() already adds copies, but the copy remover + // removes them, so we re-add them by returning the policy here. But really + // the copy remover should not be removing them. if (site.instruction()->opcode() == HloOpcode::kConditional) { policy.copy_parameters_and_constants = true; policy.copy_root_replicated_buffers = true; @@ -321,6 +322,29 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// We add copies for all the indices of the true and false computaiton roots, +// in order to resolve interference. We later rely on the CopyRemover to drop +// the unnecessary ones. +Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, + HloInstruction* conditional) { + VLOG(2) << "Adding copies for kConditional instruction " + << conditional->name(); + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + + for (HloComputation* computation : + {conditional->true_computation(), conditional->false_computation()}) { + HloInstruction* root = computation->root_instruction(); + std::vector users = root->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + computation->DeepCopyInstruction(root)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); + } + computation->set_root_instruction(deep_copy); + } + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -348,6 +372,9 @@ Status AddCopiesToResolveInterference(HloModule* module) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); } } } @@ -596,6 +623,7 @@ class CopyRemover { auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; if (LiveRangeBefore(a, b)) { VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " << b.value->ToShortString(); @@ -610,7 +638,7 @@ class CopyRemover { VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); VLOG(3) << "Source buffer values: " << ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -786,16 +814,16 @@ class CopyRemover { // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { if (a.uses.empty()) { - VLOG(2) << "Empty uses"; + VLOG(2) << "Empty uses for " << *a.value; return ordering_.IsDefinedBefore(*a.value, *b.value); } for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; + VLOG(2) << "Checking use " << *use << " against " << *b.value; if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; return false; } + VLOG(2) << "Use " << *use << " is before " << *b.value; } return true; } @@ -931,7 +959,6 @@ Status RemoveUnnecessaryCopies( CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); - tensorflow::gtl::FlatSet existing_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && @@ -940,7 +967,6 @@ Status RemoveUnnecessaryCopies( } } } - return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 91ae66ece11e70459db9a62782d3c24a303829c2..966e2d0fc5b5e21180795a07119cb028913dd176 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -670,6 +670,22 @@ cc_library( ], ) +tf_cc_test( + name = "ir_emission_utils_test", + srcs = ["ir_emission_utils_test.cc"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], @@ -772,6 +788,31 @@ cc_library( ], ) +tf_cc_test( + name = "parallel_task_assignment_test", + srcs = ["parallel_task_assignment_test.cc"], + deps = [ + ":cpu_executable", + ":parallel_task_assignment", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], @@ -875,17 +916,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 0a966fd5a7c1ce2c4e367b26701c9186ab2ebf74..e43777c5e5e8afcf08e1e334c8847f6b94d0d047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -318,7 +318,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Note this is not run for AOT because it would bring in thread pool // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). - // TODO(29630486) Support multi-threaded AOT. + // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 267b89a10b3c038dc2048f0ad5b5b343c88ef0f9..d3502b3a03e27c8f90ed74c4d826dfab1c4e8b75 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -71,11 +71,6 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU executable. - return Unimplemented("Equality test on CPU executable is not implemented."); - } - static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 6f06256e08e8e3342e77c7c79a2a47465b89eca3..8b1e20d79e90fcc32e985ffb855a1a10cdd2f2b9 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -715,6 +715,11 @@ tensorflow::Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK( + dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); + const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 788217aab6172b4e548452b3f6ffd4197c163ce4..f209a69e3cd0f8d336d61bafd1e22be8bc88ca3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution( // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); - const Shape& kernel_shape = convolution.operand(0)->shape(); + const Shape& kernel_shape = convolution.operand(1)->shape(); if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // Make sure input and kernel has the same data type. + CHECK( + ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape) || - ShapeUtil::ElementIsComplex(kernel_shape)) { + if (ShapeUtil::ElementIsComplex(input_shape)) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..215f48c4cc1a1a6b13d98dff76e0d1f0f773f5c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) { + const char* const hlo_string = R"( +HloModule ModuleWithConv + +ENTRY Conv { + input = f32[32,50,28,28]{3,2,1,0} parameter(0) + kernel = f32[0,32,5,5]{3,2,1,0} parameter(1) + ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel), + window={size=5x5}, + dim_labels=b01f_01io->b01f +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* entry_computation = module->entry_computation(); + + HloInstruction* conv_instr = entry_computation->root_instruction(); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3b8056d50500cac381a1c5ad6b05028476504a47..3405277d449f2d9e558f2d3f83277163655af592 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -438,12 +438,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, - length_32, 1); + ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, + acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, - length_32, 1); + ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, + program_buffer_address, + /*SrcAlign=*/1, length_32); } ir_builder_.CreateCall(release_func, @@ -2441,7 +2443,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { auto* memcpy_instruction = ir_builder_.CreateMemCpy( - target, source, element_count * primitive_type_size, element_alignment); + target, /*DstAlign=*/element_alignment, source, + /*SrcAlign=*/element_alignment, element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. @@ -2905,7 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index c393e9b8ea39bfb4c605ebba8e2cd29726bc4af9..87c0a3df458eb4b3f217192597e0de1576304367 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -83,12 +83,6 @@ class ParallelCpuExecutable : public Executable { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU parallel executable. - return Unimplemented( - "Equality test on CPU parallel executable is not implemented."); - } - private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index deb21bf4ef5895cfdbec5c2449b6ce7b306a7008..fb28280fade307ac1f193e7dca481bd2afa855fc 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -71,7 +71,7 @@ class DefaultCostModel : public ParallelCostModel { if (flops_to_bytes_ratio <= 1.0) { // Limit max parallelism for I/O bound instructions by assuming a // sub-linear scaling function (fit based on empirical benchmark results). - // TODO(29630486) Develop system bandwidth model. + // TODO(b/29630486) Develop system bandwidth model. max_parallelism = std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); // Use shape size instruction cost and L2 cache size min per-thread cost. @@ -81,7 +81,7 @@ class DefaultCostModel : public ParallelCostModel { // Use max parallelism for compute bound instructions. max_parallelism = max_parallelism_; // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. + // TODO(b/29630486) Improve on this linear cost model. // Consider making 'min_cost_per_thread' be a function of the target // bandwidth limit for instructions with low arithmetic complexity. instruction_cost = @@ -128,24 +128,25 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - instruction->opcode() == HloOpcode::kGetTupleElement || - instruction->opcode() == HloOpcode::kBitcast || - instruction->opcode() == HloOpcode::kFft || - (instruction->opcode() == HloOpcode::kConvolution && + auto opcode = instruction->opcode(); + if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || + opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || + opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || + opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && + (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { return 1; } + // Consult 'cost_model_' to compute target parallel task count. return cost_model_->GetParallelTaskCount(instruction); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..13eb75a57213b1a68a5732a4f6061efdf97fa4f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class ParallelTaskAssignmentTest : public HloVerifiedTestBase { + protected: + const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = + cpu::CpuExecutable::ShapeSizeBytes; + + // Use any value larger than 2 since we only test whether a module is + // parallelized or not + const int max_parallelism_ = 10; +}; + +TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_Dot + ENTRY Dot { + dot_lhs = f32[196614,2]{1,0} parameter(0) + dot_rhs = f32[2,1]{1,0} parameter(1) + ROOT dot = f32[196614,1]{1,0} dot(dot_lhs, dot_rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, + FusedComputationWithDotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_DotNestedInFusedComp + fused_computation.0 { + parameter.0 = f32[196614,2]{1,0} parameter(0) + parameter.0.1 = f32[2,1]{1,0} parameter(1) + parameter.0.2 = f32[196614,1]{1,0} parameter(2) + dot.0 = f32[196614,1]{1,0} dot(parameter.0, parameter.0.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add.0 = f32[196614,1]{1,0} add(dot.0, parameter.0.2) + + } + ENTRY DotNestedInFusedComp { + parameter = f32[196614,2]{1,0} parameter(0) + parameter.1 = f32[2,1]{1,0} parameter(1) + parameter.2 = f32[196614,1]{1,0} parameter(2) + ROOT fusion = f32[196614,1]{1,0} fusion(parameter, parameter.1, + parameter.2), kind=kOutput, calls=fused_computation.0 + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_rng + ENTRY Rng { + src0 = f32[] parameter(0) + src1 = f32[] parameter(1) + ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1), + distribution=rng_uniform + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_infeed_outfeed + ENTRY InfeedOutfeed { + infeed0 = u32[12345678,2]{1,0} infeed() + ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 61b408b8c24dded134218110d4e219c31f1685a8..42fe955f1917e0268dc739e44fbd0a7afb39185c 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -20,12 +20,13 @@ namespace cpu { std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { // Gather outer-most dims where dim_size >= 'target_partition_count'. - // Note: always leave inner-dim static for vectorization/optimizations. + // This may include the inner-dim as LLVM can vectorize loops with dynamic + // bounds. std::vector outer_dims; int64 outer_dim_size = 1; // TODO(b/27458679) Consider reserving enough minor dimensions (based on // target vector register width) to enable vector instructions. - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { const int64 dimension = shape_.layout().minor_to_major(i); outer_dims.push_back(dimension); outer_dim_size *= shape_.dimensions(dimension); diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ee0c53fa6d7c41481a53350e57e5844dea2644c1..ae80a6f4977f85cfd9f872734fd0a69432a1f382 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -30,105 +30,65 @@ class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; - void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + void RunR2Test(const Shape& shape, int64 max_target_partition_count, + const std::vector* expected_partitions) { ShapePartitionAssigner assigner(shape); - // Check all partitions of outer dimension. - for (int64 i = 1; i <= expected_max_partition_count; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), - assigner.Run(/*target_partition_count=*/i))); + // Iterate through 1..max_target_partition_count. + for (int64 i = 1; i <= max_target_partition_count; ++i) { + std::vector actual_partitions = + assigner.Run(/*target_partition_count=*/i); + EXPECT_THAT(actual_partitions, expected_partitions[i - 1]); } - // Check target_partition_count > outer dimension size. - EXPECT_TRUE(ContainersEqual( - Vec({expected_max_partition_count}), - assigner.Run( - /*target_partition_count=*/expected_max_partition_count + 1))); } }; TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); + std::vector expected_partitions[] = {{1} /* 1 */, {1, 2} /* 2 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); + std::vector expected_partitions[] = { + {1} /* 1 */, {1, 2} /* 2 */ + }; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); + std::vector expected_partitions[] = {{1} /* 1 */, {2} /* 2 */, + {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 6, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 4, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 5; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {4, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {4, 3} /* 12 */, + {4, 3} /* 13 */, {4, 3} /* 14 */, {5, 3} /* 15 */, {4, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}), 16, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 3; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */, + {2, 2} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {3, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {3, 4} /* 12 */, + {3, 4} /* 13 */, {3, 4} /* 14 */, {3, 5} /* 15 */, {3, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}), 16, + expected_partitions); } class ShapePartitionIteratorTest : public HloTestBase { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 80c24eaccfc2a83f8f3f311d60860715668d0c08..4198260a222d89c60b58dc2a11bf955715365952 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -87,7 +87,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - execution_session_(string_pool_), symbol_resolver_(llvm::orc::createLegacyLookupResolver( [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index aaeff2de8785b99d271f13b261c63118bcf7bd4a..f4260a95bc45557b6cd969f7d3fff01c8b392575 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -102,7 +102,6 @@ class SimpleOrcJIT { std::unique_ptr target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; - llvm::orc::SymbolStringPool string_pool_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ecda5288ee17a3856ce95f0caa327c3524fd180b..240faebe62f5cee4f61b3c36b5e8f653cfd6db8e 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -35,6 +35,12 @@ class HloInstruction; // DfsHloVisitor with default action based on the HloInstruction being visited. // Users should not use this class directly, but use the type aliases // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. +// +// Do *not* add an override to this class if the opcode is covered by +// HandleElementwiseUnary/Binary. These opcode handlers dispatch to +// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler +// here will break passes which rely on the HandleElementwiseUnary/Binary +// handling these opcodes. template class DfsHloVisitorWithDefaultBase : public DfsHloVisitorBase { @@ -70,12 +76,6 @@ class DfsHloVisitorWithDefaultBase Status HandleConcatenate(HloInstructionPtr concatenate) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstructionPtr convert) override { - return DefaultAction(convert); - } - Status HandleCopy(HloInstructionPtr copy) override { - return DefaultAction(copy); - } Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } @@ -91,9 +91,6 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleCompare(HloInstructionPtr compare) override { - return DefaultAction(compare); - } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class DfsHloVisitorWithDefaultTest : public HloTestBase {}; + +TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) { + // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called + // on the appropriate HLO ops (elementwise binary/unary ops). + + class ElementwiseTestVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo) override { + // The HLO should be neither an elementwise unary nor binary op. These + // cases are handled in HandleElementwiseBinary/Unary. + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2)) + << hlo->ToString(); + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1)) + << hlo->ToString(); + return Status::OK(); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + // HLO should be elementwise binary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2) + << hlo->ToString(); + return Status::OK(); + } + Status HandleElementwiseUnary(HloInstruction* hlo) override { + // HLO should be elementwise unary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1) + << hlo->ToString(); + return Status::OK(); + } + }; + + // HLO module contains are arbitrary mix of elementwise and non-elementwise + // operations. + const string& hlo_string = R"( +HloModule TestModule + +ENTRY TestComputation { + arg = f32[] parameter(0) + tuple = (f32[]) tuple(arg) + gte = f32[] get-tuple-element(tuple), index=0 + abs = f32[] abs(arg) + add = f32[] add(arg, gte) + broadcast = f32[42] broadcast(add), dimensions={} + slice = f32[0] slice(broadcast), slice={[1:2]} + copy = f32[] copy(arg) + eq = pred[] equal-to(arg, gte) + neg = f32[] negate(arg) + ROOT convert = f64[] convert(f32[] arg) +})"; + std::unique_ptr module = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) + .ConsumeValueOrDie(); + ElementwiseTestVisitor visitor; + TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index be92b1629a2d8dae57b315751bd4f7f9ccddf171..471d2fd6cebcd7a00dfea4aca08da08af534b05f 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -80,6 +80,7 @@ StatusOr> Executable::ExecuteOnStreamWrapper( StatusOr> return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); + TF_RETURN_IF_ERROR(return_value.status()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 0aee535ee780ef000bc5e9963ff48786b3a61eb2..a157235f8af6ea64a488510e427bbae502c46ca6 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -109,14 +108,6 @@ class Executable { return execution_profile_; } - // Returns Status::ok() if the two executables are equal to each other. - // - // An error status is returned otherwise. - virtual const Status EqualOrFail(const Executable& executable) { - return Unimplemented( - "Equality test on this executable is not implemented."); - } - const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); return *hlo_profile_printer_data_; diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index a133d810675814f6be7da23a2335fb19f3ff47fc..221ff7900f398166c193c495848a2afcfd4edc81 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -39,7 +39,7 @@ static StatusOr TransposeIndexVectorDimToLast( } } permutation.push_back(index_vector_dim); - return CreateTransposeHlo(gather_indices, permutation); + return MakeTransposeHlo(gather_indices, permutation); } // If the gather_indices holds scalar indices (i.e. gather_indices has rank N @@ -53,9 +53,14 @@ static StatusOr DeScalarizeGatherIndices( return gather_indices; } - int64 last_index = gather_indices_shape.dimensions( - gather_indices_shape.dimensions_size() - 1); - return ExpandLastDimIntoNDims(gather_indices, {last_index, 1}); + DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size()); + + std::vector result_shape_dims; + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(result_shape_dims)); + result_shape_dims.push_back(1); + + return MakeReshapeHlo(result_shape_dims, gather_indices); } // Canonicalizes the gather_indices tensors so that we only have deal with some @@ -81,16 +86,17 @@ static StatusOr CanonicalizeGatherIndices( // all of the non-index-vector dimensions. const Shape& shape = transposed_gather_indices->shape(); if (shape.dimensions_size() == 1) { - return ExpandFirstDimIntoNDims(gather_indices, {1, shape.dimensions(0)}); + return ExpandFirstDimIntoNDims(transposed_gather_indices, + {1, shape.dimensions(0)}); } else { return CollapseFirstNDims(transposed_gather_indices, shape.dimensions_size() - 1); } } -// Expands out the gather dimensions in the accumulator produced by the while -// loop. -static StatusOr ExpandGatherDimsInAccumulator( +// Expands out or contracts away the gather dimensions in the accumulator +// produced by the while loop. +static StatusOr AdjustGatherDimsInAccumulator( const Shape& gather_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { std::vector output_gather_dim_bounds; @@ -103,9 +109,14 @@ static StatusOr ExpandGatherDimsInAccumulator( if (output_gather_dim_bounds.empty()) { // If output_gather_dim_bounds is empty we must be lowering a (effectively) - // dynamic-slice. + // dynamic-slice. In that case, there is a leading degenerate gather + // dimension that we added to make this special case play well with the + // general while loop which we need to remove now. CHECK_EQ(accumulator->shape().dimensions(0), 1); - return CollapseFirstNDims(accumulator, 2); + ArraySlice reshaped_dim_sizes = + AsInt64Slice(accumulator->shape().dimensions()); + reshaped_dim_sizes.remove_prefix(1); + return MakeReshapeHlo(reshaped_dim_sizes, accumulator); } return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); @@ -133,16 +144,16 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( dim_numbers.gather_dims_to_operand_dims_size()) { TF_ASSIGN_OR_RETURN( HloInstruction * component_to_concat, - CreateSliceHlo( - index_vector, /*start_indices=*/{index_vector_dim_index}, - /*limit_indices=*/{index_vector_dim_index + 1}, /*strides=*/{1})); + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); expanded_index_components.push_back(component_to_concat); } else { expanded_index_components.push_back(zero); } } - return CreateConcatHlo(expanded_index_components, /*dimension=*/0); + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } // This generates the body of the while that implements the main data movement @@ -159,8 +170,8 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction * induction_var_as_vector, - CreateBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); TF_ASSIGN_OR_RETURN( HloInstruction * index_into_gather_indices, @@ -169,8 +180,8 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, - CreateDynamicSliceHlo(gather_indices, index_into_gather_indices, - {1, index_vector_size})); + MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + {1, index_vector_size})); TF_ASSIGN_OR_RETURN(HloInstruction * index_vector, ElideDegenerateDims(index_vector_2d, {0})); @@ -181,8 +192,8 @@ static StatusOr> GatherLoopBody( operand->shape().dimensions_size())); TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, - CreateDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_window_bounds())); + MakeDynamicSliceHlo(operand, gathered_slice_start, + gather.gather_window_bounds())); TF_ASSIGN_OR_RETURN( HloInstruction * gathered_slice_for_update, @@ -197,8 +208,8 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction * updated_accumulator, - CreateDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, - index_vector_into_accumulator)); + MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, + index_vector_into_accumulator)); // New loop state -- only the accumulator has changed. The // WhileUtil::MakeCountedLoop functions takes care of the induction variable @@ -250,7 +261,7 @@ static StatusOr PermuteGatherAndWindowDims( } } - return CreateTransposeHlo(accumulator, permutation); + return MakeTransposeHlo(accumulator, permutation); } // High Level Algorithm @@ -290,21 +301,38 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { + CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* gather_indices = gather_instr->mutable_operand(1); + const Shape& gather_indices_shape = gather_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); const GatherDimensionNumbers& dim_numbers = gather_instr->gather_dimension_numbers(); + int64 gather_loop_trip_count = 1; + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + gather_loop_trip_count *= gather_indices_shape.dimensions(i); + } + } + + if (!IsInt32(gather_loop_trip_count)) { + return Unimplemented( + "Gather operations with more than 2147483647 gather indices are not " + "supported. This error occurred for %s.", + gather_instr->ToString().c_str()); + } + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, CanonicalizeGatherIndices( gather_indices, dim_numbers.index_vector_dim())); - const int64 gather_loop_trip_count = - canonical_gather_indices->shape().dimensions(0); + CHECK_EQ(gather_loop_trip_count, + canonical_gather_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, @@ -331,7 +359,7 @@ StatusOr GatherExpander::ExpandGather( TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - ExpandGatherDimsInAccumulator(gather_indices->shape(), + AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_with_window_dims_elided, dim_numbers.index_vector_dim())); @@ -341,12 +369,17 @@ StatusOr GatherExpander::ExpandGather( } StatusOr GatherExpander::Run(HloModule* module) { + auto is_nontrivial_gather = [](HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::HasZeroElements(inst->shape()); + }; + std::vector gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - [](HloInstruction* inst) { - return inst->opcode() == HloOpcode::kGather; - }); + is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba41ee8428cbe7132103df24d552565a8dc2f9f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { +TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2147483647,5] parameter(1) + ROOT gather = s32[2147483647,3,5] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + Status status = GatherExpander{}.Run(module.get()).status(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + + ASSERT_THAT( + status.error_message(), + ::testing::HasSubstr("Gather operations with more than 2147483647 gather " + "indices are not supported.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a3b7e10ae8df080879ce98b02b83f246bb19204b..f1707442fe3354d5183d905468810f3871146ff5 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -241,6 +241,7 @@ cc_library( "gpu_executable.cc", "infeed_thunk.cc", "kernel_thunk.cc", + "memset_thunk.cc", "sequential_thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", @@ -257,6 +258,7 @@ cc_library( "gpu_executable.h", "infeed_thunk.h", "kernel_thunk.h", + "memset_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -273,6 +275,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -293,6 +296,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep + "//tensorflow/stream_executor", ], ) @@ -696,17 +700,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index c67b552abbdc971351f99ec89536af78479b87c1..07be2a0cf90c326af6e41764e79950db546e43e4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -671,6 +671,8 @@ StatusOr> GpuCompiler::RunBackend( if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 04b37d913e0bc8f8226057f107da05fd1e675010..28f93447953b90d8a7fa4386e2355066c0405aec 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -267,16 +267,22 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - // The caller must give us a buffer for ShapeIndex {} of every parameter. - // It can optionally give us a buffer for other ShapeIndices, but we - // ignore them: Because we can't rely on these sub-buffers' addresses - // being available, our generated code can't use them. Instead, it must - // chase pointers starting at the tuple root. - if (allocation.param_shape_index().empty()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->root_buffer()); + auto param_no = allocation.parameter_number(); + se::DeviceMemoryBase buffer = + arguments[param_no]->buffer(allocation.param_shape_index()); + + // All top-level buffers and sub-buffers must have an explicit, non-null + // pointer, except for zero-sized buffers, which may be null. + if (buffer.is_null() && buffer.size() > 0) { + return FailedPrecondition( + "Cannot run XLA computation because pointer to (sub-)buffer at " + "index %s of parameter %lld was null. All pointers to " + "(sub-)buffers must not be null, unless the (sub-)buffer has zero " + "elements.", + allocation.param_shape_index().ToString().c_str(), param_no); } + + buffer_allocations_builder.RegisterBuffer(i, buffer); } } se::StreamExecutor* executor = run_options->stream()->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index b19cfd43debd0a5490495d176fa2f1fcd625da07..dcb3991f41a31db84d8e9e555ae7d13c3ac84b97 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -83,11 +83,6 @@ class GpuExecutable : public Executable { const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on GPU executable. - return Unimplemented("Equality test on GPU executable is not implemented."); - } - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 2381d7a7d59ba2777e711138779b4493b8037f3d..d29cc21ab1c697f8481ed1e94846d4df5ec5c1dc 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -44,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -142,37 +145,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } -// Tries to get a Slice for the given instruction at the given index, but -// returns nullopt if we might not know the slice's address at runtime without -// dereferencing a containing tuple. -// -// In particular, when XLA accepts a parameter of tuple type, the caller has the -// option of telling XLA what are the values inside of the tuple, or just giving -// XLA a pointer to the top-level tuple and letting us chase the pointers on the -// GPU. We therefore cannot rely having these pointers to parameter sub-buffers -// being present when we run the program. -optional GetKnownAtRuntimeSlice( - const HloInstruction* instr, const ShapeIndex& index, - const BufferAssignment& buffer_assn) { - auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); - if (!maybe_slice.ok()) { - return nullopt; - } - // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, - // but we don't necessarily know the runtime address of sub-buffers of input - // parameters. - const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); - const BufferAllocation* alloc = slice.allocation(); - if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && - !alloc->param_shape_index().empty()) { - return nullopt; - } - - // Otherwise, we will know the address of this slice at runtime without having - // to dereference a tuple. - return slice; -} - } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -203,7 +175,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -213,13 +185,13 @@ bool ImplementedAsDeviceToDeviceMemcpy( // // 1. `hlo` is a kCopy instruction. // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. The operand to `hlo` has a buffer assignment (constants do not, for - // instance) which means the source buffer also resides on the device. + // 3. `hlo` and its operand have a statically-known buffer assignment + // (constants do not, for instance), which means the source buffer also + // resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && - GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) - .has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && + buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); } } // namespace @@ -498,12 +470,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(fusion)); - TF_RETURN_IF_ERROR(EmitInitializer( - fusion, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(fusion)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; @@ -1635,14 +1606,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { if (IsReductionToVector(*reduce) && // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(reduce)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(reduce)); - TF_RETURN_IF_ERROR(EmitInitializer( - reduce, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(reduce)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(reduce)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); + return EmitReductionToVector( reduce, input->shape(), [&](const llvm_ir::IrArray::Index& index) { @@ -1706,16 +1677,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter( CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); CHECK_EQ(rank, window.dimensions_size()); - { - std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - TF_RETURN_IF_ERROR(EmitInitializer( - select_and_scatter, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(select_and_scatter)); + std::vector> thunks; + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1960,38 +1928,54 @@ GetHloBufferSlices(const HloInstruction* hlo, -> optional> { // Simple, common case: Is the buffer for instr known at runtime? If so, // we're done. - auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); - if (slice.has_value()) { - return {{*slice, ShapeIndex()}}; + auto slice = buffer_assn.GetUniqueSlice(instr, index); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } - // If we don't know the buffer for instr at index, see if we know the buffer - // for instr at index without its last element. If so, we can dynamically - // find the buffer for instr by dereferencing a pointer in that buffer. - // Continue looking this way until we run out of elements in 'index'. - ShapeIndex new_index = index; - ShapeIndex gte_indices; - while (!new_index.empty()) { - gte_indices.push_front(new_index.back()); - new_index.pop_back(); - auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + // If that didn't work, walk up any bitcasts that we might see. These must + // appear before any GTE instructions, because it's illegal to bitcast to a + // tuple type. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kBitcast) { + parent = parent->operand(0); + + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } } - // If *that* didn't work, check whether instr is a GTE instruction. If it - // is, see if we can get a buffer for its parent, and continue walking up - // parents until we find a defined buffer or we hit something that's not a - // GTE. - const HloInstruction* parent = instr; + // Check whether instr is a GTE instruction. If it is, see if we can get a + // buffer for its parent, and continue walking up parents until we find a + // defined buffer or we hit something that's not a GTE. + ShapeIndex gte_indices; while (parent->opcode() == HloOpcode::kGetTupleElement) { gte_indices.push_front(parent->tuple_index()); parent = parent->operand(0); - auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; + } + } + + // Finally, if we don't know the buffer for instr at index, see if we know + // the buffer for instr at index without its last element. If so, we can + // dynamically find the buffer for instr by dereferencing a pointer in that + // buffer. Continue looking this way until we run out of elements in + // 'index'. + // + // We can almost always get a buffer without resorting to this. The only + // exception is for cases where the relevant sub-buffer is truly unknowable, + // for example the sub-buffer of a tuple-shaped select. + ShapeIndex new_index = index; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = buffer_assn.GetUniqueSlice(instr, new_index); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; } } @@ -2036,7 +2020,7 @@ Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { return Unimplemented("Gather is not implemented on GPUs."); } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2260,37 +2244,87 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } -Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, - KernelThunk* thunk) { +StatusOr> IrEmitterUnnested::BuildInitializerThunk( + const HloInstruction* hlo) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - CHECK(inst->opcode() == HloOpcode::kSelectAndScatter || - inst->opcode() == HloOpcode::kReduce); - const HloInstruction* init_value = nullptr; - switch (inst->opcode()) { - case HloOpcode::kSelectAndScatter: - init_value = inst->operand(2); - break; - case HloOpcode::kReduce: - init_value = inst->operand(1); - break; - default: - LOG(FATAL) << "Opcode " << inst->opcode() - << " should not need an initializer."; - } + const HloInstruction* init_value = [&] { + switch (inst->opcode()) { + case HloOpcode::kSelectAndScatter: + return inst->operand(2); + case HloOpcode::kReduce: + return inst->operand(1); + default: + LOG(FATAL) << "Opcode " << inst->opcode() + << " should not need an initializer."; + } + }(); if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } - return EmitTargetElementLoopInThunk( + // In the common case, the initializer is a constant. In this case, emit a + // device-memset call if we can. Currently StreamExecutor only supports + // zeroing and 32-bit memsets. + if (init_value->IsConstant()) { + CHECK(ShapeUtil::IsScalar(init_value->shape())); + int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape()); + const auto& literal = init_value->literal(); + + // Are all the bytes of this scalar equal to 0? If so, we can create a + // MemzeroThunk. + ArraySlice literal_bytes( + reinterpret_cast(literal.untyped_data()), num_bytes); + if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + uint16 pattern16; + if (num_bytes == 1) { + uint8 b = literal_bytes.front(); + pattern16 = uint16{b} | (uint16{b} << 8); + } else { + pattern16 = literal_bytes.front(); + } + uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); + return {MakeUnique(pattern32, + GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(literal_bytes.data(), literal_bytes.data() + 4, + literal_bytes.size() - 4) == 0) { + uint32 word; + memcpy(&word, literal_bytes.data(), sizeof(word)); + return {MakeUnique(word, GetAllocationSlice(*hlo), + hlo)}; + } + } + + // Otherwise fall back to our slow initializer code. + std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, - thunk); + kernel_thunk.get())); + + // Clean up state left behind by emitting the loop above. (This is normally + // done in IrEmitterUnnested::Postprocess().) + bindings_.UnbindAllLocalIrValues(); + + // Convert unique_ptr to StatusOr>. + return {std::move(kernel_thunk)}; } namespace { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b83a2337e2decd9d4fba3d40fcf33f131fca8a3c..66c62e2d2de3ed1668271a21943dc73ed3d77651 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -148,13 +148,10 @@ class IrEmitterUnnested : public IrEmitter { tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reducer); - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + std::unique_ptr BuildKernelThunk(const HloInstruction* inst); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -163,6 +160,11 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + // Returns a thunk that, given a reduce or select-and-scatter op, initializes + // its memory to the appropriate initial value. + StatusOr> BuildInitializerThunk( + const HloInstruction* hlo); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index f4c4dcdafd6cc0cd64da5a8d1f23c8c0e7b2a9cb..86c4ac18b0501c38aaaae5a007bddcf261ca338f 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -68,17 +68,3 @@ tf_cc_test( "@llvm//:support", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..18e673542c5b47cb90d31a8eff62a5e4adb78d1d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +namespace se = ::perftools::gputools; + +Status MemzeroThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemZero(&dest_data, dest_data.size()); + return Status::OK(); +} + +Status Memset32BitValueThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemset32(&dest_data, value_, dest_data.size()); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..b4bb74d1dd6dc9d09c5e4d439d57dfe8b57c2ed9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/stream_executor/stream_executor.h" + +// This file contains thunks that set a buffer's elements to a particular value. +// This can be faster than emitting a kernel to set the elements. + +namespace xla { +namespace gpu { + +// Thunk that zeroes out a given chunk of memory. +class MemzeroThunk : public Thunk { + public: + explicit MemzeroThunk(const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemzero, hlo), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice dest_; +}; + +// Thunk that sets a given chunk of memory to a particular 32-bit value. The +// destination chunk must have size divisible by 32 bits. +class Memset32BitValueThunk : public Thunk { + public: + explicit Memset32BitValueThunk(uint32 value, + const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + uint32 value_; + const BufferAllocation::Slice dest_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index fa405b9329a327a70161821212db4d3213e834b7..7bda4e2fcd469bd430e5ef1846251c8504225383 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -69,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - input = CreatePadHlo(input, padding, padding_config).ValueOrDie(); + input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } if (window_util::HasNegativePadding(conv_window)) { @@ -92,8 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::max(0LL, -conv_window.dimensions(i).padding_high()); } - input = CreateSliceHlo(input, start_indices, limit_indices, strides) - .ValueOrDie(); + input = + MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); } return input; @@ -126,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - return CreatePadHlo(kernel, padding, padding_config).ValueOrDie(); + return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -238,7 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = - CreatePadHlo(input, padding, input_padding_config).ValueOrDie(); + MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 2c3032d79be221e8cacb178ffb1817459b603cc0..9eea958d1214b131d49cb4e28f1944860408d3a8 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -51,6 +51,8 @@ class Thunk { kGemm, kInfeed, kKernel, + kMemset32BitValue, + kMemzero, kSequential, kTuple, kWhile, diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index bf903d6a390fe2951d33942dfc2e124868c9fdb5..0b446c654779db410ebbd91ef9a5bab14d08a278 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING. -// -// Don't use these protos in the real compilation or execution codepaths. The -// data format is meant for debugging only, and may change without notice. +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. // // Many of the protos below are simple 1-to-1 serializations of the -// corresponding C++ classes. +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. // // FIELD NAMES ARE IMPORTANT // @@ -38,16 +37,19 @@ option cc_enable_arenas = true; message HloInstructionProto { reserved 10; reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; string name = 1; string opcode = 2; xla.Shape shape = 3; - // TODO(b/67782397): Replace instruction names with HloInstruction ids. - repeated string operand_names = 4; - repeated string control_predecessor_names = 5; - repeated string called_computation_names = 6; - xla.OpMetadata metadata = 7; // Literal, only present for kConstant. @@ -58,7 +60,6 @@ message HloInstructionProto { // Fusion state, only present for kFusion. string fusion_kind = 11; - HloComputationProto fused_instructions_computation = 12; // Index for kGetTupleElement. int64 tuple_index = 13; @@ -136,30 +137,40 @@ message HloInstructionProto { // The id of this instruction. int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; + + xla.OpSharding sharding = 40; } // Serialization of HloComputation. message HloComputationProto { + reserved 3; + reserved "root_name"; + string name = 1; // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; - // The name of the root of the computation. - string root_name = 3; - // The program shape (with layout) of this computation. xla.ProgramShape program_shape = 4; // The id of this computation. int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; } // Serialization of HloModule. message HloModuleProto { string name = 1; string entry_computation_name = 2; + int64 entry_computation_id = 6; // The array of computations is always in a valid dependency order, where // callees appear before their callers. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 30e32a46d7dd0923f738939c33407ac7484b5bbe..a88283ed9a6459b4fa9310e160b59c77d51f1027 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -171,24 +171,21 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } - // Compute and return a vector of buffers that the given value must be - // contained in due to HLO aliasing rules. - std::vector ComputeAliasedBuffers(const HloValue& value) { + void ComputeWhileAliasedBuffers(const HloValue& value, + std::vector* aliased_buffers) { + VLOG(3) << "Compute kWhile aliases"; // Value is init of a while (use is while). - std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = dataflow_.GetUniqueValueAt(use.instruction, use.operand_index); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); VLOG(3) << " value is init value to a while; must share buffer with " "while value " << while_value.ToShortString(); } } - // Value is a parameter of a while body/condition. if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { const HloComputation* computation = @@ -205,11 +202,10 @@ class BufferValueMap { VLOG(3) << " value is parameter value of the body or condition of a " "while; must share buffer with while value " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } - // Value is the root of a while body. for (const HloPosition& position : value.positions()) { const HloComputation* computation = position.instruction->parent(); @@ -224,27 +220,71 @@ class BufferValueMap { const HloValue& while_value = dataflow_.GetUniqueValueAt( callsite.instruction(), position.index); - VLOG(3) << " value is root the body computation of a while; must " - "share buffer with while value " + VLOG(3) << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; body root and while value root must share buffer " + "among them : " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } } - // Value is the output of the while instruction itself. if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { VLOG(3) << " value is output of a while instruction"; - aliased_buffers.push_back(GetBufferForValue(value)); + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + + void ComputeConditionalAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute kConditional aliases"; + // Aliases the buffers of the true/false computations roots, with the one of + // the conditional. + for (const HloPosition& position : value.positions()) { + const HloComputation* computation = position.instruction->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + if (position.instruction == computation->root_instruction()) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& cond_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), position.index); + VLOG(3) + << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; true/false branch roots must share buffer among them : " + << cond_value.ToShortString(); + aliased_buffers->push_back(GetBufferForValue(cond_value)); + } + } + } + } + // Value is the output of the conditional instruction itself. + if (value.defining_instruction()->opcode() == HloOpcode::kConditional) { + VLOG(3) << " value is output of a conditional instruction"; + aliased_buffers->push_back(GetBufferForValue(value)); } + } + // Compute and return a vector of buffers that the given value must be + // contained in due to HLO aliasing rules. + std::vector ComputeAliasedBuffers(const HloValue& value) { + for (const HloUse& use : value.uses()) { + VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; + } + std::vector aliased_buffers; + ComputeWhileAliasedBuffers(value, &aliased_buffers); + ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. std::sort(aliased_buffers.begin(), aliased_buffers.end()); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); - return aliased_buffers; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index f99c7cf5e495eaf83e0dda859ef31a7487bc6ffe..594413e88fb26e86b198d08b2e4db77fad671348 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -65,6 +65,7 @@ HloComputation::HloComputation( std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), + unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); @@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->UniquifyName(&parent()->instruction_name_uniquer()); instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } - Reparent(instruction.get()); + instruction->set_parent(this); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); @@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -void HloComputation::Reparent(HloInstruction* instruction) { - instruction->set_parent(this); -} - bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -307,19 +304,15 @@ void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, std::list* post_order) { - if (visited->count(computation) > 0) { - return; - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); + if (visited->insert(computation).second) { + for (auto* instruction : computation->instructions()) { + for (HloComputation* called_computation : + instruction->called_computations()) { + ComputeComputationPostOrder(called_computation, visited, post_order); + } } + post_order->push_back(computation); } - - visited->insert(computation); - post_order->push_back(computation); } } // namespace @@ -393,12 +386,16 @@ string HloComputation::ToString(const HloPrintOptions& options) const { HloComputationProto HloComputation::ToProto() const { HloComputationProto proto; + CHECK(unique_id_ != -1) + << "This computation does not have a valid id. Please make sure the " + "computation is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } - proto.set_root_name(root_instruction()->name()); + proto.set_root_id(root_instruction()->unique_id()); *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } @@ -406,31 +403,29 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction) { + const tensorflow::gtl::FlatMap& computation_map) { std::vector> instructions; - tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } - TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); - instruction_map[instruction->name()] = instruction.get(); + TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); + instruction_map[instruction_proto.id()] = instruction.get(); instructions.push_back(std::move(instruction)); } - TF_RET_CHECK(!proto.root_name().empty()); - TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); - HloInstruction* root = instruction_map.at(proto.root_name()); - return WrapUnique(new HloComputation( - proto.name(), parameter_count, &instructions, root, fusion_instruction)); + TF_RET_CHECK(proto.root_id() != -1); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); + HloInstruction* root = instruction_map.at(proto.root_id()); + return WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dd9d346999f0eae448d74628278c802ccd3f51b4..9d3f6e9a2c2efd97681a22b6b0f6d929afc553de 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -160,20 +160,12 @@ class HloComputation { // module: the module which will contain the computation. The newly created // computation is *not* added to the module, however. // proto: the proto to convert from. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used only when the instruction is a fusion instruction. - // fusion_instruction: if non-null then the newly created computation will - // be constructed as a fused computation with this instruction as its - // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction = nullptr); + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // @@ -342,6 +334,15 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // The id of this computation should be unique within the module. + void SetUniqueId(int64 id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + unique_id_ = id; + } + + int64 unique_id() const { return unique_id_; } + private: explicit HloComputation( const string& name, int parameter_count, @@ -352,10 +353,6 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); - // Helper for setting the parent of instructions that are added to this - // computation. - void Reparent(HloInstruction* instruction); - // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. @@ -373,6 +370,7 @@ class HloComputation { std::vector CollectUnreachableRoots() const; string name_; + int64 unique_id_; HloInstruction* root_instruction_; // If this computation is a fusion computation, this field points to the diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 4ec2ef27bf59b0c877ec38e55ef5c12debeec227..44e4f75f75b275653e1a07111943843fc6f78b33 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -379,20 +380,101 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { } Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { - auto rhs_instruction = convolution->operand(1); + auto lhs = convolution->operand(0); + auto rhs = convolution->operand(1); + Window window = convolution->window(); + const auto& result_shape = convolution->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const auto& dnums = convolution->convolution_dimension_numbers(); - const int64 output_features = - convolution->shape().dimensions(dnums.output_feature_dimension()); - - // For each output element, we do one fma per element in the kernel at some - // given output feature index. - const int64 fmas_per_output_element = - output_features > 0 - ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features - : 0; - const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_properties_[kFlopsKey] = - output_elements * fmas_per_output_element * kFmaFlops; + + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_feature_dim = dnums.input_feature_dimension(); + const int64 output_feature_dim = dnums.output_feature_dimension(); + const int64 input_feature = + ShapeUtil::GetDimension(lhs_shape, input_feature_dim); + const int64 output_feature = + ShapeUtil::GetDimension(result_shape, output_feature_dim); + const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + DimensionVector kernel_limits; + DimensionVector output_limits; + DimensionVector input_limits; + if (window.dimensions().empty()) { + window = window_util::MakeWindow({1}); + kernel_limits.push_back(1); + output_limits.push_back(1); + input_limits.push_back(1); + } else { + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + // Spatial dimension number for kernel (rhs). + const int64 kernel_spatial_dim = + dnums.kernel_spatial_dimensions(spatial_dimension); + const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim); + kernel_limits.push_back(kernel_limit); + + // Spatial dimension number for output. + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(spatial_dimension); + const int64 output_limit = result_shape.dimensions(output_spatial_dim); + output_limits.push_back(output_limit); + + // Spatial dimension number for input (lhs). + const int64 input_spatial_dim = + dnums.input_spatial_dimensions(spatial_dimension); + const int64 input_limit = lhs_shape.dimensions(input_spatial_dim); + input_limits.push_back(input_limit); + } + } + + DimensionVector valid_position_counts; + + // Loop over each spatial dimension. + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + int64 valid_position_count = 0; + // Loop over each point in the kernel. + for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension]; + ++kernel_idx) { + // Loop over each point in the output. + for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension]; + ++output_idx) { + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(spatial_dimension); + const int64 undilated_index = output_idx * window_dim.stride() - + window_dim.padding_low() + + kernel_idx * window_dim.window_dilation(); + + // Calculate the actual lhs (input) index after dilation. Avoid the + // division as an optimization. + const int64 lhs_spatial_index = + window_dim.base_dilation() > 1 + ? undilated_index / window_dim.base_dilation() + : undilated_index; + + // Skip if the lhs (input) index is to be dilated. + if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) { + continue; + } + + // Skip if input index is not in bound. + if (lhs_spatial_index < 0 || + lhs_spatial_index >= input_limits[spatial_dimension]) { + continue; + } + + valid_position_count += 1; + } + } + valid_position_counts.push_back(valid_position_count); + } + + const int64 fma_count = + input_feature * output_feature * batch * Product(valid_position_counts); + current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3b289c240a45e8f3df8156ed89e879da2132d01a..3d055b327ee920dac9c0904c69e1461206b31203 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -186,12 +186,14 @@ TEST_F(HloCostAnalysisTest, Map) { TEST_F(HloCostAnalysisTest, Convolution) { ComputationBuilder builder(client_, "convolution"); auto input = builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, - /*x_dim=*/20}), + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), "input"); auto kernel = builder.Parameter( - 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, - /*x_dim=*/3}), + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), "kernel"); auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); @@ -440,5 +442,32 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { + ComputationBuilder builder(client_, "BaseDilatedConvolution"); + auto input = builder.Parameter( + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = builder.Parameter( + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + + auto result = builder.ConvGeneralDilated( + input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(), 1472); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 4585bffa42fd001c7375b778c5e1c42e58d17692..b186767ce792cd89ae77fe9a03b3a2ecf296b804 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -23,8 +23,8 @@ namespace xla { using tensorflow::gtl::ArraySlice; using tensorflow::strings::StrCat; -StatusOr CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs, - HloInstruction* rhs) { +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape binary_op_shape, @@ -33,9 +33,9 @@ StatusOr CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); } -StatusOr CreatePadHlo(HloInstruction* operand, - HloInstruction* padding_value, - const PaddingConfig& padding_config) { +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, padding_value->parent()); TF_ASSIGN_OR_RETURN( @@ -46,10 +46,10 @@ StatusOr CreatePadHlo(HloInstruction* operand, pad_shape, operand, padding_value, padding_config)); } -StatusOr CreateSliceHlo(HloInstruction* operand, - ArraySlice start_indices, - ArraySlice limit_indices, - ArraySlice strides) { +StatusOr MakeSliceHlo(HloInstruction* operand, + ArraySlice start_indices, + ArraySlice limit_indices, + ArraySlice strides) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -58,7 +58,7 @@ StatusOr CreateSliceHlo(HloInstruction* operand, slice_shape, operand, start_indices, limit_indices, strides)); } -StatusOr CreateConvolveHlo( +StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) { HloComputation* computation = lhs->parent(); @@ -70,8 +70,8 @@ StatusOr CreateConvolveHlo( convolve_shape, lhs, rhs, window, dimension_numbers)); } -StatusOr CreateTransposeHlo(HloInstruction* operand, - ArraySlice dimensions) { +StatusOr MakeTransposeHlo(HloInstruction* operand, + ArraySlice dimensions) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -80,23 +80,23 @@ StatusOr CreateTransposeHlo(HloInstruction* operand, HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); } -StatusOr CreateReshapeHlo(const Shape& result_shape, - HloInstruction* operand) { +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand) { HloComputation* computation = operand->parent(); return computation->AddInstruction( HloInstruction::CreateReshape(result_shape, operand)); } -StatusOr CreateReshapeHlo( +StatusOr MakeReshapeHlo( ArraySlice result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_dim_bounds); - return CreateReshapeHlo(new_shape, operand); + return MakeReshapeHlo(new_shape, operand); } -StatusOr CreateDynamicSliceHlo(HloInstruction* operand, - HloInstruction* start_indices, - ArraySlice slice_sizes) { +StatusOr MakeDynamicSliceHlo(HloInstruction* operand, + HloInstruction* start_indices, + ArraySlice slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); TF_ASSIGN_OR_RETURN( @@ -107,7 +107,7 @@ StatusOr CreateDynamicSliceHlo(HloInstruction* operand, dynamic_slice_shape, operand, start_indices, slice_sizes)); } -StatusOr CreateDynamicUpdateSliceHlo( +StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { HloComputation* computation = operand->parent(); @@ -121,7 +121,7 @@ StatusOr CreateDynamicUpdateSliceHlo( dynamic_update_slice_shape, operand, update, start_indices)); } -StatusOr CreateBroadcastHlo( +StatusOr MakeBroadcastHlo( HloInstruction* operand, ArraySlice broadcast_dimensions, ArraySlice result_shape_bounds) { HloComputation* computation = operand->parent(); @@ -132,8 +132,8 @@ StatusOr CreateBroadcastHlo( broadcast_shape, operand, broadcast_dimensions)); } -StatusOr CreateGetTupleElementHlo(HloInstruction* operand, - int64 index) { +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( @@ -143,8 +143,8 @@ StatusOr CreateGetTupleElementHlo(HloInstruction* operand, HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); } -StatusOr CreateConcatHlo(ArraySlice operands, - int64 dimension) { +StatusOr MakeConcatHlo(ArraySlice operands, + int64 dimension) { CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); @@ -181,7 +181,7 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); - return CreateReshapeHlo(output_shape, operand); + return MakeReshapeHlo(output_shape, operand); } StatusOr ExpandFirstDimIntoNDims( @@ -198,25 +198,7 @@ StatusOr ExpandFirstDimIntoNDims( std::back_inserter(expanded_shape_dim_bounds)); Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), expanded_shape_dim_bounds); - return CreateReshapeHlo(new_shape, operand); -} - -StatusOr ExpandLastDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { - CHECK_GT(operand->shape().dimensions_size(), 0); - CHECK_EQ(operand->shape().dimensions(operand->shape().dimensions_size() - 1), - Product(expanded_dims)); - - std::vector expanded_shape_dim_bounds; - expanded_shape_dim_bounds.reserve(expanded_dims.size() + - operand->shape().dimensions_size() - 1); - std::copy(operand->shape().dimensions().begin(), - operand->shape().dimensions().end() - 1, - std::back_inserter(expanded_shape_dim_bounds)); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); - Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), - expanded_shape_dim_bounds); - return CreateReshapeHlo(new_shape, operand); + return MakeReshapeHlo(new_shape, operand); } StatusOr ElideDegenerateDims(HloInstruction* operand, @@ -241,7 +223,7 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); - return CreateReshapeHlo(output_shape, operand); + return MakeReshapeHlo(output_shape, operand); } StatusOr PadVectorWithZeros(HloInstruction* operand, @@ -258,7 +240,7 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(operand->shape().element_type())))); - return CreatePadHlo(operand, zero, padding_config); + return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( @@ -267,8 +249,8 @@ StatusOr BroadcastZeros( HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - return CreateBroadcastHlo(zero, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/broadcast_dimensions); + return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/broadcast_dimensions); } StatusOr> CreateComputationWithSignature( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 2b03a849cff35008a96eaedd212ab1aa24695822..d99e32a737e6aaa2ff746cf6c00d4300cf62f4e1 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -28,73 +28,73 @@ namespace xla { // Creates a binary HLO instruction and adds it to the computation containing // `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). -StatusOr CreateBinaryHlo(HloOpcode opcode, HloInstruction* lhs, - HloInstruction* rhs); +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs); // Creates a pad HLO instruction and adds it to the computation containing // `operand` and `padding_value` (`operand` and `padding_value` must be in the // same computation). -StatusOr CreatePadHlo(HloInstruction* operand, - HloInstruction* padding_value, - const PaddingConfig& padding_config); +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); // Creates a slice HLO instruction and adds it to the computation containing // `operand`. -StatusOr CreateSliceHlo( +StatusOr MakeSliceHlo( HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). -StatusOr CreateConvolveHlo( +StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr CreateTransposeHlo( +StatusOr MakeTransposeHlo( HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. -StatusOr CreateReshapeHlo(const Shape& result_shape, - HloInstruction* operand); +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand); -StatusOr CreateReshapeHlo( +StatusOr MakeReshapeHlo( tensorflow::gtl::ArraySlice result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). -StatusOr CreateDynamicSliceHlo( +StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes); // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and // `start_indices` must be in the same computation). -StatusOr CreateDynamicUpdateSliceHlo( +StatusOr MakeDynamicUpdateSliceHlo( HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices); // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. -StatusOr CreateBroadcastHlo( +StatusOr MakeBroadcastHlo( HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions, tensorflow::gtl::ArraySlice result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. -StatusOr CreateGetTupleElementHlo(HloInstruction* operand, - int64 index); +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index); // Creates a Concatenate HLO instruction and adds it to the computation // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). -StatusOr CreateConcatHlo( +StatusOr MakeConcatHlo( tensorflow::gtl::ArraySlice operands, int64 dimension); // ----------------------------------------------------------------------------- @@ -119,16 +119,6 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); StatusOr ExpandFirstDimIntoNDims( HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); -// Expands (via reshape) the last (logical) dimension of `operand` into a -// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 -// and the number of elements in its last dimension must be equal to the -// product of `expanded_dims`. -// -// For instance if `operand` has shape f32[9,7,200] and expanded_dims is -// {2,5,20} the result is `operand` reshaped to [9,7,2,5,20]. -StatusOr ExpandLastDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); - // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. Every dimension in // `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 279edd4ba8772a9c576f76f554de8ec68631b953..cd7cbbdd71706fddb64855f631eb09de35da52e8 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -109,6 +109,11 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { + continue; + } + // An instruction is considered to be equivalent to another only if they // share the exact same set of operands. So to find equivalent // instructions, we just search among instructions which share operand(0) @@ -118,7 +123,7 @@ StatusOr HloCSE::Run(HloModule* module) { tensorflow::gtl::InlinedVector equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && + if (user != instruction && !user->HasSideEffect() && user->Identical(*instruction, eq_instructions, eq_computations, is_layout_sensitive_)) { equivalent_instructions.push_back(user); diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 3601a790c4428ee39c264b217a4b9a991ad8456c..df8853f34f6a72c52d1cde7332ada3809d2f3d96 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -414,8 +414,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { EXPECT_THAT(root, op::Add(rng1, rng2)); } -// TODO(b/28245743): Handle impure functions correctly in CSE. -TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { +TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. @@ -458,14 +457,16 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Add(op::Map(), op::Map())); + VLOG(3) << "before: " << module->ToString(); + HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + VLOG(3) << "after: " << module->ToString(); EXPECT_EQ(4, computation->instruction_count()); root = computation->root_instruction(); - auto operand = root->operand(0)->operand(0); - EXPECT_THAT(operand, op::Map()); - EXPECT_THAT(root, op::Add(operand, operand)); + EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 934e43ba4879628362009267c671ec4cb0d79c52..0c37a8d75f38dabaad886cc9d4adce8ab29ddf18 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -368,11 +368,11 @@ bool HloDataflowAnalysis::UpdateConditionalValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( conditional->false_computation()->root_instruction())}; - // A phi-node is not defined for a kConditional instruction even though it - // represents a join point. This is because the current approach is to define - // a phi-node only for kWhile to account for the dataflow through back-edges - // and deal with the ambiguity in other cases. - return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + if (ssa_form_) { + return Phi(conditional, inputs); + } else { + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + } } bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 7bf3a1a06045c79621d75b653bf42220705a69d4..07f69b8e1339fed636e4eb54791941b85e09fd17 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1602,11 +1602,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), ElementsAre(HloUse{conditional, 2, {}})); - EXPECT_EQ(analysis.values().size(), 3); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), - analysis.GetValueDefinedAt(constant2))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 4); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + } } TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { @@ -1713,11 +1719,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); - EXPECT_EQ(analysis.values().size(), 6); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(add), - analysis.GetValueDefinedAt(sub))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 7); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); + } } TEST_P(HloDataflowAnalysisTest, NestedConditionals) { @@ -1834,20 +1846,27 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), analysis.GetValueDefinedAt(constant2)); - EXPECT_EQ(analysis.values().size(), 9); - EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT( - HloValuesAt(inner_conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()))); - EXPECT_THAT( - HloValuesAt(conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()), - analysis.GetValueDefinedAt(computation3->root_instruction()))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 11); + EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); + } } INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 91341b5d35d85b904715fb5a059f51fff13ac4da..9d7251b6ae94c8ffd14db980f18df077c9767ae7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1520,14 +1520,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { arg_dim_counts[dim] = arg_dimensions[dim]; } - // Create mapping from result index to arg index. - const int64 result_rank = ShapeUtil::Rank(result->shape()); - int64 result_dim = 0; - std::vector result_to_arg_index(result_rank); + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; for (int64 i = 0; i < arg_dimensions.size(); ++i) { if (arg_dim_steps[i] == 0) { - result_to_arg_index[result_dim] = i; - ++result_dim; + result_to_arg_index.push_back(i); } } @@ -1542,6 +1540,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { base[result_to_arg_index[i]] = multi_index[i]; } + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } auto func = [&](ArraySlice input_index) { auto curr_val = arg_literal.Get(input_index); @@ -1554,19 +1566,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::unique_ptr computed_result = embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on + // Clear visit states so that we can use the evaluator again on // the same computation. embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. result_val = computed_result->Get({}); - return true; }; - + // Computes one element of the result, reducing all dimensions that + // contribute to that element. ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); - return result_val; })); @@ -1574,6 +1584,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { auto operand = select_and_scatter->operand(0); auto source = select_and_scatter->operand(1); @@ -2771,6 +2795,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), /*output_shape=*/shape); + const Shape& operand_shape = operand.shape(); + auto gather_inner_loop_body = [&](ArraySlice output_window_index, ArraySlice input_gather_index, @@ -2780,9 +2806,16 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { output_window_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; + DCHECK_LT(output_index[i], shape.dimensions(i)); } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_gather_index[i] + input_window_index[i]; + // TODO(b/74360564): We should implement whatever out of bounds behavior + // we decide for dynamic-slice here as well. + input_index[i] = (input_gather_index[i] + input_window_index[i]) % + operand_shape.dimensions(i); + if (input_index[i] < 0) { + input_index[i] += operand_shape.dimensions(i); + } } TF_RETURN_IF_ERROR( result->CopyElementFrom(operand, input_index, output_index)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 685cacd7f74c00789296dee16f0a6a94c35a4393..dd14dd38537a83d0ee16cff9e3c22a38f544e208 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1205,6 +1206,80 @@ TEST_P(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; + +// Tests that Reduce doesn't lose precision when adding many numbers (because +// it accumulates its result in a double). +TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + HloComputation::Builder b(TestName()); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction* init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module().AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + std::unique_ptr result = + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, *result); +} + +// Reducing many numbers should be fast because it doesn't create +// intermediate Literals; the microbenchmark should finish in < 1 msec. +void BM_ReducePrecisely(int num_iters) { + tensorflow::testing::StopTiming(); + HloComputation::Builder b("BM_ReducePrecisely"); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module.AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + tensorflow::testing::StartTiming(); + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_ReducePrecisely); + TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index f0df93b61d29c1535d8a89fbd65e669de5b43729..c3ccbf0f0c75b569b49652807dea52faebdccc31 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -111,8 +111,8 @@ HloExecutionProfile::HloExecutionProfile( : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( - /*count*/ hlo_profile_index_map_.total_count(), - /*value*/ 0) {} + /*count=*/hlo_profile_index_map_.total_count(), + /*value=*/0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1dc72355cf179e996caab4d6b52068dc99d02244..25702dc65ea1ebd9d91b3382dcb909e606628202 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -823,7 +823,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::StringPiece(constant->name()).starts_with("constant")) { + if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -1041,8 +1041,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::StringPiece(instr->name()) - .starts_with(HloOpcodeString(instr->opcode()))) { + if (tensorflow::str_util::StartsWith(instr->name(), + HloOpcodeString(instr->opcode()))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d33add23d07b52cb56e4b212a29b415259af7694..fcf9ebf5f787445f5e89f126e9f2393fd3bd1790 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -51,24 +52,22 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation) { + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const string& operand_name : proto.operand_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) - << "No instruction named " << operand_name; - instruction->AppendOperand(instruction_map.at(operand_name)); - } - for (const string& predecessor_name : proto.control_predecessor_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) - << "No instruction named " << predecessor_name; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) ->AddControlDependencyTo(instruction.get())); } @@ -76,26 +75,36 @@ StatusOr> HloInstruction::CreateFromProto( // HloInstructionProto and do not appear as an HloComputationProto within the // HloModuleProto. if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + fused_computation->SetFusionInstruction(instruction.get()); + instruction->called_computations_.push_back(fused_computation); } else { - for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) - << "No computation named " << computation_name; + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; instruction->called_computations_.push_back( - computation_map.at(computation_name)); + computation_map.at(computation_id)); } } + if (instruction->opcode() == HloOpcode::kTrace) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "Trace instruction should have 1 operand but sees " + << instruction->operands().size(); + instruction->mutable_operand(0)->set_tracing(instruction.get()); + } + TF_RET_CHECK(!proto.name().empty()); instruction->name_ = proto.name(); @@ -168,6 +177,7 @@ StatusOr> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_ = Literal::CreateR1U8(tag); + operand->set_tracing(instruction.get()); return instruction; } @@ -2313,14 +2323,18 @@ string HloInstruction::ToShortString() const { HloInstructionProto HloInstruction::ToProto() const { HloInstructionProto proto; + CHECK(unique_id_ != -1) + << "This instruction does not have a valid id. Please make sure the " + "instruction is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); *proto.mutable_shape() = shape_; for (const HloInstruction* operand : operands_) { - *proto.add_operand_names() = operand->name(); + proto.add_operand_ids(operand->unique_id()); } for (const HloInstruction* control : control_predecessors_) { - *proto.add_control_predecessor_names() = control->name(); + proto.add_control_predecessor_ids(control->unique_id()); } *proto.mutable_metadata() = metadata_; @@ -2330,11 +2344,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); } else { for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); + proto.add_called_computation_ids(computation->unique_id()); } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e4c86214c2014095b2e171ff10691e1221574cb7..80f84082442798d240a0a8e11d85ceaf638a4695 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -179,20 +179,15 @@ class HloInstruction { // module: the module which will contain the instruction. The newly created // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. - // instruction_map: a map from instruction name to HloInstruction*. This map + // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation); + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -933,6 +928,13 @@ class HloInstruction { const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } + // Returns the sharding unique device, if any. + tensorflow::gtl::optional sharding_unique_device() const { + if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + return tensorflow::gtl::optional(); + } + return sharding_->UniqueDevice().ValueOrDie(); + } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cdea3d597824d155241a544d226aa18d3b0b0274..08b9a29aeda2ee612d49b0788acf8438a25eb6a3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal( for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); + computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -204,14 +209,11 @@ string HloModule::ToString(const HloPrintOptions& options) const { HloModuleProto HloModule::ToProto() const { HloModuleProto proto; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); + proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are added when the fusion instructions are created by - // HloInstruction::CreateFromProto. - if (computation->IsFusionComputation()) { - continue; - } HloComputationProto computation_proto = computation->ToProto(); if (computation->name() == entry_computation_->name()) { *proto.mutable_program_shape() = computation_proto.program_shape(); @@ -235,8 +237,8 @@ StatusOr> HloModule::CreateFromProto( for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = module_config.entry_computation_layout().parameter_layout(i).shape(); - TF_RET_CHECK( - ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), + parameter_shape)) << "HloModuleConfig has different shape for parameter " << i << " than the HLO module. Expected: " << ShapeUtil::HumanStringWithLayout( @@ -245,7 +247,8 @@ StatusOr> HloModule::CreateFromProto( } const Shape& result_shape = module_config.entry_computation_layout().result_layout().shape(); - TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + TF_RET_CHECK( + ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " "Expected: " << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) @@ -254,26 +257,20 @@ StatusOr> HloModule::CreateFromProto( auto module = MakeUnique(proto.name(), entry_computation_handle, module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); - TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); - string computation_name = computation->name(); + int64 computation_id = computation_proto.id(); + TF_RET_CHECK(computation_id != -1); + TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_name] = module->AddComputationInternal( + computation_map[computation_id] = module->AddComputationInternal( std::move(computation), - /*is_entry=*/proto.entry_computation_name() == computation_name, + /*is_entry=*/proto.entry_computation_id() == computation_id, /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -283,10 +280,6 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); @@ -302,12 +295,13 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( - const HloModuleProto& module) { + const HloModuleProto& module, const DebugOptions& debug_options) { TF_RET_CHECK(module.has_program_shape()) << "No program shape found in the proto"; const auto& program_shape = module.program_shape(); HloModuleConfig module_config(program_shape); + module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 755bbd359f7b95e7f3f3cbee1b46df85908202c6..9f7f25202ba42b14e995ed5c47d1012dabc69332 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -172,7 +172,7 @@ class HloModule { // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( - const HloModuleProto& module); + const HloModuleProto& module, const DebugOptions& debug_options); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index fa5dcb0b369d17c70c64c67b9f11640c93fb4278..54c34ce116651608e6d91cdcba9c708ca3a5f75e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -313,6 +313,27 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } + const HloModule* send_module = channel.send->parent()->parent(); + const HloModule* send_done_module = channel.send_done->parent()->parent(); + if (send_module != send_done_module) { + return FailedPrecondition( + "send and send-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + } + const HloModule* recv_module = channel.recv->parent()->parent(); + const HloModule* recv_done_module = channel.recv_done->parent()->parent(); + if (recv_module != recv_done_module) { + return FailedPrecondition( + "recv and recv-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + } + if (send_module == recv_module) { + return FailedPrecondition( + "send and recv (channel=%lld) must be on different devices: %lld", + channel.id, GetModuleId(send_module)); + } } // Check if channel instructions are used only in allowed computations. diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 1b24d8da9e832e6847cb6f405e15af3c455f695a..e89d94bede6c437ca1131a1b1b0098390d58c0d9 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -66,6 +66,28 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } } + // If the common ancestor is a conditional instruction, even though the true + // and false computations are not really ordered per-se, we define the true + // computation to be ordered before the false one. + // This ensures that buffers can still be shared among the two computations + // as they will forcibly have disjoint liveness. + if (a_ancestor == b_ancestor && + a_ancestor->opcode() == HloOpcode::kConditional) { + const HloComputation* true_computation = a_ancestor->true_computation(); + const HloComputation* false_computation = a_ancestor->false_computation(); + if (call_graph_->InstructionIsNestedIn(a, true_computation) && + call_graph_->InstructionIsNestedIn(b, false_computation)) { + return true; + } + // If 'b' is the conditional ancestor, and 'a' is within the true or false + // computations, 'a' executes before 'b'. + if (b == a_ancestor && + (call_graph_->InstructionIsNestedIn(a, true_computation) || + call_graph_->InstructionIsNestedIn(a, false_computation))) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -118,7 +140,18 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { b.defining_instruction()->while_condition()))) { return true; } - + // If 'b' is a conditional phi and 'a' is in the true or false computation, + // then 'a' executes before 'b'. + if (b.is_phi() && + b.defining_instruction()->opcode() == HloOpcode::kConditional && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->true_computation()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->false_computation()))) { + return true; + } return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); } @@ -212,18 +245,17 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() << ", b = " << b.ToShortString() << ")"; if (!IsDefinedBefore(a, b)) { - VLOG(4) << "a not defined before b"; + VLOG(4) << a << " not defined before " << b; return false; } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { - VLOG(4) << "use of a (" << use << ") not before b is defined"; + VLOG(4) << "use of " << a << " (" << use << ") not before " << b + << " is defined"; return false; } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index a989fce63234cb860d08c48b02462e96bec879bc..37a7fbad97cea2f34798efecc2489e57d1374f35 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -34,53 +34,6 @@ namespace { class HloOrderingTest : public HloTestBase {}; -TEST_F(HloOrderingTest, LastUseScheduledFirst) { - // Tests scheduling of the following HLO code: - // - // %ab = abs(%param) - // %exp = exp(%param) - // %add = add(%ab, %exp) - // %negate = negate(%exp) - // %sub = subtract(%add, %negate) - // - // %add should be scheduled before %negate because %add is the last (and only) - // use of %ab. Scheduling %add first then frees up %ab's buffer. - const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); - auto ab = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); - auto sub = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - - SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // Tests the ordering of instructions in different computations using the // following HLO code: @@ -362,5 +315,66 @@ ENTRY while.v11 { ordering.ToString(); // Shouldn't crash. } +TEST_F(HloOrderingTest, ConditionalInstructionOrdering) { + const char* module_str = R"( +HloModule test_conditional_module + +true_branch { + param.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1) +} + +false_branch { + param.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1 + add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4) + ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4) +} + +ENTRY root { + param.3 = (pred[], (s32[], s32[])) parameter(0) + pred.1 = pred[] get-tuple-element(param.3), index=0 + cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1 + conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch + cond_res.1 = s32[] get-tuple-element(conditional), index=0 + cond_res.2 = s32[] get-tuple-element(conditional), index=1 + add.3 = s32[] add(cond_res.1, cond_res.2) + ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Even though the true and false branches has no ordering, since they do not + // interfere (as they are mutually exclusive), we define the true computation + // to be before the false one. + // Similarly, any instruction in the true or false branches are considered + // before the conditional instruction. The roots are effectively "at the same + // time" WRT the conditional, but they are Phi-ed anyway. + HloInstruction* add_1 = FindInstruction(module.get(), "add.1"); + HloInstruction* add_2 = FindInstruction(module.get(), "add.2"); + HloInstruction* add_3 = FindInstruction(module.get(), "add.3"); + HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_2))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_3))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(add_3))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 98b8d34be1f331aaeac94e952deeae1e76379861..b0632448933df4b7681a0704c58d697b5ec68a1f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1320,7 +1320,7 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SchedulerAlgorithm scheduler_algorithm, + MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { HloRematerialization remat(scheduler_algorithm, size_function); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 52553439033a3bcfa4b472f13f9cd4b1ecf5ed96..2ee2dd0571ae8c6604e4ca722351fd48a913bda5 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -66,12 +66,12 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, const ShapeSizeFunction& size_function) : scheduler_algorithm_(scheduler_algorithm), size_function_(size_function) {} @@ -108,7 +108,7 @@ class HloRematerialization { const HloInstruction* instruction) const; // Selects an algorithm to use for HLO scheduling. - SchedulerAlgorithm scheduler_algorithm_; + MemorySchedulerAlgorithm scheduler_algorithm_; // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 1b7d26dde501a6a0955d62ea0938e0683a32d49d..83de54f3fa56ee660b79d8c366dbc0b52f9fde87 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -162,7 +162,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/14 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -195,7 +195,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/20 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -236,7 +236,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/17 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -272,7 +272,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/15 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -314,7 +314,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/13 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), SchedulerAlgorithm::kAuto, &sequence)); + module.get(), DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -480,7 +480,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -577,7 +577,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index e5b1c2efa3fc25d23531df298e125521c002dba1..ec7d8210a70ad7498f77fe807abd53544d4b0487 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -52,10 +52,9 @@ namespace { // Creates an HloModule from the given proto. StatusOr> HloProtoToModule( const HloProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN( - HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto.hlo_module())); - config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module(), + debug_options)); TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(proto.hlo_module(), config)); return std::move(module); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index da448ed71ab470e0c4d72e234bf1f1087d3ea7b4..1a767628f6e2d33df353366974fb866e89f0df5a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -103,10 +103,11 @@ class ListScheduler { for (auto* instruction : computation.instructions()) { tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - instr_uses.insert(buffer); - } + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + instr_uses.insert(buffers.begin(), buffers.end()); + }); } buffer_uses_[instruction] = std::vector( instr_uses.begin(), instr_uses.end()); @@ -339,7 +340,33 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> RunDFSMemoryScheduler( +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function); +} + +} // namespace + +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { @@ -396,32 +423,17 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } -StatusOr MinimumMemoryForComputation( +StatusOr> ListMemoryScheduler( const HloComputation& computation, - const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; + return ListScheduler::Run(computation, points_to_analysis, size_function); } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm == SchedulerAlgorithm::kListSchedule) { - return ListScheduler::Run(computation, points_to_analysis, size_function); - } - if (algorithm == SchedulerAlgorithm::kDfsSchedule) { - return RunDFSMemoryScheduler(computation, points_to_analysis, - size_function); - } - + const LogicalBuffer::SizeFunction& size_function) { // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -431,7 +443,7 @@ StatusOr> CreateMemoryMinimizingSequence( // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, @@ -440,7 +452,7 @@ StatusOr> CreateMemoryMinimizingSequence( TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + DFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, @@ -458,12 +470,10 @@ StatusOr> CreateMemoryMinimizingSequence( } } -} // namespace - StatusOr CreateMemoryMinimizingSequence(const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -479,7 +489,7 @@ CreateMemoryMinimizingSequence(const HloModule& module, StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 1d1eb1e064f75c2220b39e84b010e720a0c37880..068e68383deb170ded1c9b09a8b7ceb8c4c0ab4b 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,28 +34,48 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); -enum class SchedulerAlgorithm { - kListSchedule, - kDfsSchedule, +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. +typedef std::function>( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&)> + MemorySchedulerAlgorithm; - // Selects the available scheduler algorithm that had the minimum memory in - // the resulting sequence (a la MinimumMemoryForSequence). - kAuto, -}; +// List scheduler +StatusOr> ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// DFS-order scheduler +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const MemorySchedulerAlgorithm& algorithm = {}); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 7fb338e7042ce19ac9647e23719e738f3ef42c7c..74544c4a67a819d341056aba4cf6b321a5a86c0a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -89,5 +90,105 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); } +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. + const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : + sequence.at(module->entry_computation())) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), + sequence.at(module->entry_computation()).front()); + EXPECT_EQ(instructions_by_name.at("result"), + sequence.at(module->entry_computation()).back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index aa9ff89e983aa5d35a18906afca1c6e8eeaefa06..e8e45f1ee968992901988e8b85d4e9ae28f2abe9 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -20,6 +20,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; HloSharding HloSharding::AssignDevice(int64 device_id) { @@ -57,8 +58,9 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", - "devices=", VectorString(tile_assignment_), "}"); + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", + Join(tile_assignment_.dimensions(), ","), "]", + Join(tile_assignment_, ","), "}"); } } @@ -374,4 +376,9 @@ HloSharding HloSharding::TransformShardedTileShape( return HloSharding::Tile(new_tile_shape, tile_assignment()); } +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { + out << sharding.ToString(); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index e715dff9a0b8fcc2301a1581919dba384206923c..06204acbca30648e73382cb4641139e852664b77 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -94,6 +94,10 @@ class HloSharding { // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); + // Checks whether device is a reserved device number. A reserved device number + // has usually a special meaning, with dedicated handling logic. + static bool IsReservedDevice(int64 device) { return device < 0; } + OpSharding ToProto() const; string ToString() const; @@ -173,7 +177,7 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && + ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -207,6 +211,13 @@ class HloSharding { // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } + // Returns the flattened list of all the leaf shardings in a tuple shape, by + // pre-order walk (ShapeTree iterator order). + // REQUIRES: IsTuple(). + const std::vector& tuple_elements() const { + return tuple_elements_; + } + // Return a new sharding that can apply to the given new shape. // If this sharding is tile-maximal, the returned sharding will be the same as // this sharding. If this sharding is not tile-maximal, the returned @@ -262,6 +273,8 @@ class HloSharding { std::vector tuple_elements_; }; +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 07fc4687cc1c0518b3ab2a86c62464fc54082a01..69ea4233e45c2e59c8d1541a0517a007f4bbf42f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -282,5 +282,44 @@ TEST_F(HloShardingTest, TransformShardedTileShapeTest) { EXPECT_EQ(result, expected); } +TEST_F(HloShardingTest, ToStringReplicatedTest) { + HloSharding sharding = HloSharding::Replicate(); + EXPECT_EQ(sharding.ToString(), "{replicated}"); +} + +TEST_F(HloShardingTest, ToStringAssignDeviceTest) { + HloSharding sharding = HloSharding::AssignDevice(7); + EXPECT_EQ(sharding.ToString(), "{maximal device=7}"); +} + +TEST_F(HloShardingTest, ToStringTiledTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), + Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); +} + +TEST_F(HloShardingTest, ToStringTupleTest) { + HloSharding sharding = HloSharding::Tuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), + ShapeUtil::MakeShape(U32, {7, 25}), + ShapeUtil::MakeShape(S32, {9, 11})}), + {HloSharding::Replicate(), + HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), + Array2D({{3, 5}})), + HloSharding::AssignDevice(3)}); + EXPECT_EQ(sharding.ToString(), + "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); +} + +TEST_F(HloShardingTest, OstreamTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + std::ostringstream oss; + oss << sharding; + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 0819ab3b90b2360c6b0b2afaa89f322afe566eb3..45505484951abfcee93a62fec7a99e86cbb9150c 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -63,10 +63,7 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - deps = [ - "@nsync//:nsync_headers", - "//tensorflow/core:stream_executor_headers_lib", - ] + if_static( + deps = ["//tensorflow/core:stream_executor_headers_lib"] + if_static( ["@protobuf_archive//:protobuf"], ["@protobuf_archive//:protobuf_headers"], ), @@ -123,14 +120,3 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 37261ed1e665ebed9685751161a412ad114a9e96..f1e7fc29532ce7e6841010a5258f4000a7c70383 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,17 +169,3 @@ cc_library( "@llvm//:core", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2a282f3be79f847a6569416794d1a2a3fcd69148..ec04239b4f9112134ba876fdfbb3905a3baf1f72 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -762,7 +763,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 07f989d4faea199e812e54d2ae74d3ff9e7fa19a..499f280211aacd00e79b3ca0ddb3413f933b02da 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { + +// Retrieves the parameter metadata for the given computation and parameter +// number. +// +// If the parameter number is invalid for this computation, nullopt is +// returned. When the return value has_value(), nullptr will never be +// the held value. +tensorflow::gtl::optional ParameterMetadata( + const XlaComputation& computation, int parameter_number) { + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() == computation.proto().entry_computation_id()) { + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == parameter_number) { + if (!instr.has_metadata()) { + return tensorflow::gtl::nullopt; + } + return &instr.metadata(); + } + } + } + } + return tensorflow::gtl::nullopt; +} + +ExecutionOptions CreateExecutionOptions( + const ExecutableBuildOptions& build_options, + const ProgramShape* program_shape) { + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (build_options.hlo_profile().has_value()) { + execution_options.mutable_debug_options()->set_xla_hlo_profile( + *build_options.hlo_profile()); + } + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.dump_optimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_optimized_hlo_proto_to( + build_options.dump_optimized_hlo_proto_to().value()); + } + if (build_options.dump_per_pass_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_per_pass_hlo_proto_to( + build_options.dump_per_pass_hlo_proto_to().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); + } else { + *execution_options.mutable_shape_with_output_layout() = + program_shape->result(); + LayoutUtil::SetToDefaultLayout( + execution_options.mutable_shape_with_output_layout()); + } + return execution_options; +} + +} // namespace + StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -118,30 +180,78 @@ StatusOr> LocalService::CompileExecutable( *build_options.result_layout(), program_shape->result())); } - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, program_shape.get()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(*program_shape, argument_layouts, + &execution_options, user_computation)); + + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); + + return BuildExecutable(versioned_handle, std::move(module_config), + execute_backend_.get(), executor, + build_options.device_allocator()); +} + +StatusOr> LocalService::CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options) { + const HloModuleProto& proto = computation.proto(); + TF_RET_CHECK(proto.has_program_shape()); + const ProgramShape& program_shape = proto.program_shape(); + + // Validate incoming layouts. + if (argument_layouts.size() != program_shape.parameters_size()) { + return InvalidArgument( + "Invalid number of arguments for computation: expected %d, got %zu.", + program_shape.parameters_size(), argument_layouts.size()); + } + + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { + tensorflow::gtl::optional metadata = + ParameterMetadata(computation, /*parameter_number=*/i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; + return InvalidArgument( + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(argument_shape).c_str()); + } } if (build_options.result_layout() != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); - } else { - *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape.result())); } + + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - *user_computation)); + CreateModuleConfig(program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(versioned_handle, std::move(module_config), + return BuildExecutable(proto, std::move(module_config), execute_backend_.get(), executor, build_options.device_allocator()); } diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 15e120685e1be9190d49fdaf5ed6706bdf991a6c..06567cabd6eb28aae53881613cd6beb78e25e222 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -50,6 +51,18 @@ class LocalService : public Service { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Builds an Executable with the given XlaComputation, argument layouts and + // options. If result_layout is non-null, then the executable is compiled to + // produce a result of the given layout. If device_allocator is non-null, + // then the compiler may use it to allocate temp space on the device. The + // compiler is responsible for freeing any memory it allocates this way. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options); + // Returns the device ordinal that corresponds to the given replica number. // // This returns an error if there is not a one-to-one correspondence of diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b0e1270702621c9ea7b2ee43e001fe0..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,8 +53,8 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } -// Returns true iff `instruction` can change its shape simply by adjusting -// metadata. +// Returns true if `instruction` can change its shape simply by adjusting +// metadata or if `instruction` is a broadcast of a scalar value. bool CanTriviallyChangeShape(const HloInstruction* instruction) { // NOTE: Technically a sequence of reshape(reshape(constant)) is also // trivially reshapable, so we might be tempted to simply recurse if @@ -88,19 +88,31 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + + // A broadcase of scalar can trivially change its shape. + if (instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape())) { + return true; + } + return false; } -// Finds the first non-scalar operand of an instruction that is a non-trivial -// reshape or transpose. Returns the operand if it is found or nullptr if not -// found. +// Returns true iff `instruction` is a reshape/transpose instruction for which +// a shape change is nontrivial. +bool IsNontrivialReshape(const HloInstruction* instruction) { + return !ShapeUtil::IsScalar(instruction->shape()) && + IsReshapeOrTranspose(instruction) && + !CanTriviallyChangeShape(instruction->operand(0)); +} + +// Finds the first operand of an instruction that is a non-trivial reshape or +// transpose. Returns such an operand or nullptr if not found. HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { - if (!ShapeUtil::IsScalar(operand->shape()) && - IsReshapeOrTranspose(operand) && - !CanTriviallyChangeShape(operand->operand(0))) { - VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " + if (IsNontrivialReshape(operand)) { + VLOG(5) << "Found first non-trivial reshape operand of " << hlo->ToString(HloPrintOptions().set_print_metadata(false)) << ":\n\t" << operand->ToString(HloPrintOptions().set_print_metadata(false)); @@ -110,7 +122,7 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( return nullptr; } -// Returns whether `a` and `b` are equivalent for the purposes of this pass. +// Returns whether `a` and `b` are equivalent reshapes/transposes. bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { if (a->opcode() != b->opcode() || !ShapeUtil::SameDimensions(a->shape(), b->shape())) { @@ -127,71 +139,14 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if all operands of `instruction` can easily change shape. -// Operands can easily change shape if they are all reshapes/transposes to and -// from the same shape. Additionally, operands like constant, rng, and any -// scalar change shape with only an adjustment of metadata. -bool AllOperandsHaveEasyShapeChanges( - const HloInstruction* instruction, - const HloInstruction* first_reshape_operand) { - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, it may be - // implicitly broadcast, which can confound the movement's - // correctness. - // - // And one of the following: - // 1. Are reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. Are one of kConstant, kRng, and scalars that can change shape - // trivially, - for (const HloInstruction* operand : instruction->operands()) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToString(print_no_metadata) << "\n\tinstruction: " - << instruction->ToString(print_no_metadata); - return false; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToString(print_no_metadata) - << "\n\toperand: " << operand->ToString(print_no_metadata); - continue; - } - - if (CanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToString(print_no_metadata); - continue; - } - - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " - << operand->ToString(print_no_metadata); - return false; - } - - VLOG(3) << "All operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - return true; -} - // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that // plays nicely with `new_operand_shape`; either it has the same shape (of the // correct type), or it is a scalar that may be implicitly broadcast. -HloInstruction* UpdateOperand(HloComputation* computation, - const HloInstruction* first_reshape_operand, +HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { + HloComputation* computation = operand->parent(); const PrimitiveType element_type = operand->shape().element_type(); const Shape new_shape = ShapeUtil::ChangeElementType(new_operand_shape, element_type); @@ -222,36 +177,24 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: { + CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape())); + HloInstruction* inst = computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + VLOG(5) << "Changing broadcast from " << operand->ToString() << " to " + << inst->ToString(); + return inst; + } + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } } -// Try to sink any reshape or transpose operands of `instruction` across it. We -// do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. -StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - // Only perform sinks for live elementwise instructions with operands. - const bool is_dead = instruction->user_count() == 0 && - instruction != computation->root_instruction(); - if (!instruction->IsElementwise() || instruction->operands().empty() || - is_dead) { - return false; - } - - // Only perform sinks if there are any nontrivial reshape/transpose operands. - const HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - if (!first_reshape_operand) { - return false; - } - - // Only perform sinks if all operands can easily change shape. - if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { - return false; - } - +// Actually performs the reshape-move transformation -- that is, sinks the +// reshape or transpose operands of `instruction` across it. +StatusOr PerformSinkReshapeOrTranspose( + HloInstruction* instruction, const HloInstruction* first_reshape_operand) { auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); @@ -272,8 +215,8 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, } VLOG(3) << "Updating operand #" << i << ": " << operands[i]->ToString(print_no_metadata); - operands[i] = UpdateOperand(computation, first_reshape_operand, - new_operand_shape, operands[i]); + operands[i] = + UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is @@ -285,6 +228,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, *shape->mutable_layout() = new_operand_shape.layout(); } } + HloComputation* computation = instruction->parent(); HloInstruction* new_elementwise = computation->AddInstruction(instruction->CloneWithNewOperands( // `instruction` may change the element type, e.g., from @@ -319,6 +263,141 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return true; } +// Returns true if the instruction is a reshape-move candidate. +// +// An instruction is a reshape-move candidate if the instruction is elementwise, +// has at least one nontrivial reshape/transpose operand, and its operands are +// either trivially reshapable or are equivalent nontrivial reshapes/transposes. +bool IsReshapeMoveCandidate(HloInstruction* instruction) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + VLOG(5) << "** Checking instruction: " + << instruction->ToString(print_no_metadata); + + // Only perform reshape-move for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { + return false; + } + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, they may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars + // that can change shape trivially. + const HloInstruction* first_reshape_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); + return false; + } + + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToString(print_no_metadata); + continue; + } + + if (!IsNontrivialReshape(operand)) { + VLOG(5) << "Operand can't trivially change shape: " + << operand->ToString(print_no_metadata); + return false; + } + + if (first_reshape_operand == nullptr) { + first_reshape_operand = operand; + VLOG(5) << "First reshape operand " + << operand->ToString(print_no_metadata); + } else if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) + << "Operand is an equivalent reshape of the first reshape operand " + << operand->ToString(print_no_metadata); + } else { + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is a reshape but is not equivalent to the first " + "Reshape operand" + << operand->ToString(print_no_metadata); + return false; + } + } + + if (first_reshape_operand) { + VLOG(5) << "All operands have easy shape changes: " + << instruction->ToString(print_no_metadata); + } + + return first_reshape_operand != nullptr; +} + +// Reshape-moves all qualifying instructions in reshape_candidates. Returns +// true if it makes changes. +// +// `reshape_candidates` is a set of HloInstructions with nontrivial reshape +// operands, and a instruction in the set can be reshape-moved iff all the users +// of its nontrivial reshape operands can also be reshaped-moved. +// +// The algorithm here iteratively finds the nontrivial operands with users that +// are outside the set of `reshape_candidates`, and removes their users from +// `reshape_candidates`, until either `reshape_candidates` becomes empty or none +// of the remaining nontrivial operands have users outside `reshape_candidates`. +// In the later case, all the remaining instructions in `reshape_candidates` +// are reshape-moved and the routine returns true. +StatusOr TryReshapeMoveOnCandidates( + HloInstructionSet* reshape_candidates) { + bool removed = true; + while (!reshape_candidates->empty() && removed) { + if (VLOG_IS_ON(5)) { + for (const HloInstruction* instruction : *reshape_candidates) { + VLOG(5) << "candidate " << instruction->ToString(); + } + } + ConstHloInstructionSet nontrivial_operands; + for (const HloInstruction* instruction : *reshape_candidates) { + for (const auto* operand : instruction->operands()) { + if (IsNontrivialReshape(operand)) { + nontrivial_operands.insert(operand); + } + } + } + + removed = false; + for (auto operand : nontrivial_operands) { + if (c_any_of(operand->users(), [&](HloInstruction* user) { + return !reshape_candidates->count(user); + })) { + for (auto* user : operand->users()) { + removed |= reshape_candidates->erase(user) > 0; + } + } + } + } + + if (reshape_candidates->empty()) { + return false; + } + for (HloInstruction* instruction : *reshape_candidates) { + const HloInstruction* first_reshape_operand = + FirstNonScalarAndNonTrivialReshapeOperand(instruction); + TF_ASSIGN_OR_RETURN( + bool did_change, + PerformSinkReshapeOrTranspose(instruction, first_reshape_operand)); + CHECK(did_change); + } + return true; +} + } // namespace StatusOr ReshapeMover::Run(HloModule* module) { @@ -326,11 +405,15 @@ StatusOr ReshapeMover::Run(HloModule* module) { VLOG(2) << "Pre ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp, instruction)); - changed |= did_change; + HloInstructionSet reshape_candidates; + for (HloInstruction* instruction : comp->instructions()) { + if (IsReshapeMoveCandidate(instruction)) { + reshape_candidates.insert(instruction); + } } + TF_ASSIGN_OR_RETURN(bool did_change, + TryReshapeMoveOnCandidates(&reshape_candidates)); + changed |= did_change; } VLOG(2) << "Post ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index aac8638a54f744f0c230ec6c5ca071c1daf45ab2..094f7319f462a71f4bfe972771a1de4aedbb8ee3 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -560,5 +560,95 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); } +TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { + const string hlo_string = R"( + HloModule TransposeMulInversedTransposeModule + ENTRY TransposeMulInversedTranspose { + src0 = f32[20,8]{1,0} parameter(0) + transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0} + src1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Transpose(op::Multiply())); +} + +TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { + const string hlo_string = R"( + HloModule ReshapeWithUsersOutsideCandidates + ENTRY ReshapeWithMultipleUsers { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + param4 = f32[8,20]{1,0} parameter(4) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, param4) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates1 + ENTRY ReshapeWithMultipleUsers1 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, reshape2) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates2 + ENTRY ReshapeWithMultipleUsers2 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Reshape(op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 0becc9d8f8ed22b2d7174b76ce775efec4b646f5..ec883a6cf3ce9546ac54f5c2524a8eda53bad33f 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -272,7 +272,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -286,8 +286,15 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { + if (user_computation == nullptr) { + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + } return InvalidParameterArgument( - *user_computation.ParameterMetadata(i).value(), + *user_computation->ParameterMetadata(i).value(), "Argument does not match shape of computation parameter %d: want %s, " "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), @@ -330,7 +337,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); @@ -402,6 +409,37 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } +StatusOr>> Service::BuildExecutables( + const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf("BuildExecutable on service %p", this); + + VLOG(1) << "Computations:"; + for (const HloModuleProto* proto : module_protos) { + VLOG(1) << proto->name(); + } + + CHECK_EQ(module_protos.size(), module_configs.size()); + std::vector> modules; + for (int64 i = 0; i < module_protos.size(); ++i) { + const HloModuleProto* proto = module_protos[i]; + const HloModuleConfig& config = *module_configs[i]; + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(*proto, config)); + modules.push_back(std::move(module)); + } + + TF_ASSIGN_OR_RETURN( + std::vector> executables, + backend->compiler()->Compile(std::move(modules), std::move(executors), + device_allocator)); + + return std::move(executables); +} + StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, @@ -696,6 +734,47 @@ tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, return computation->SetReturnValue(arg->operand()); } +StatusOr> +Service::GetExecutors(const ExecutionOptions& execution_options, + int64 requests_size, int64 request_index) const { + if (execution_options.device_handles().empty()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); + } + if (requests_size > 1 && execution_options.device_handles_size() > 1) { + return InvalidArgument( + "Parallel requests with multiple device handles is not supported. " + "Found %lld parallel requests, with request %lld containing %d device " + "handles.", + requests_size, request_index, execution_options.device_handles_size()); + } + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } + return executors; +} + +StatusOr>> Service::GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice arguments) { + // Resolve the allocations for the arguments of the computation, and create + // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, execution_options.device_handles(0))); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arguments, replicas)); + return replicated_arguments; +} + tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); @@ -724,26 +803,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // is one of the executors to run the replicated computation. const ExecutionOptions& execution_options = arg->requests(i).execution_options(); - if (execution_options.device_handles().empty()) { - return FailedPrecondition( - "device handles must be given to execute parallel computations"); - } - if (arg->requests_size() > 1 && - execution_options.device_handles_size() > 1) { - return InvalidArgument( - "Parallel requests with multiple device handles is not supported. " - "Found %d parallel requests, with request %lld containing %d device " - "handles.", - arg->requests_size(), i, execution_options.device_handles_size()); - } - std::vector executors; - for (const auto& device_handle : execution_options.device_handles()) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, device_handle)); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); - executors.push_back(executor); - } + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -760,16 +823,9 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); - // Resolve the allocations for the arguments of the computation, and create - // a vector of device memory offsets for the arguments from the allocations. - // In the case of partitioned computations, assume all arguments go on the - // zeroth core. - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, execution_options.device_handles(0))); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(request.arguments(), replicas)); + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. Here, we care only about the @@ -778,7 +834,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), *user_computation)); + request.execution_options(), user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -830,6 +886,107 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { + VLOG(1) << "running execute-graph-parallel request"; + + std::vector>> all_arguments; + std::vector> all_executors; + std::vector module_protos; + std::vector> module_configs; + std::vector computation_names; + std::vector device_handles; + + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteGraphRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > + execute_backend_->device_count()) { + return FailedPrecondition( + "there are not enough stream executors to execute %d computations", + num_requested_devices); + } + + for (int64 i = 0; i < arg->requests_size(); ++i) { + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + const ExecuteGraphRequest& request = arg->requests(i); + TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; + TF_RET_CHECK(request.computation().has_program_shape()) + << "programe shape may not be empty"; + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); + + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); + + // Create an HloModuleConfig object for the computation, given the shape of + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(request.computation().program_shape(), + replicated_arguments.front(), + request.execution_options(), + /*user_computation=*/nullptr)); + VLOG(3) + << "ExecuteGraphParallel created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + // Adds to the vectors to build and execute the computations after the loop. + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); + module_protos.push_back(&request.computation()); + module_configs.push_back(std::move(module_config)); + computation_names.insert(computation_names.end(), executors.size(), + request.computation().name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); + } + + // Build the HloModules and compile to generate the executables. + // + // TODO(jlebar): There's currently no way to pass a device allocator to + // ExecuteGraphParallel, so we have to pass a null device_allocator below. + TF_ASSIGN_OR_RETURN(std::vector> executables, + BuildExecutables(module_protos, std::move(module_configs), + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); + std::vector executable_ptrs; + executable_ptrs.reserve(executables.size()); + for (const auto& executable : executables) { + executable_ptrs.push_back(executable.get()); + } + + // Execute the generated executables in parallel and return the device + // handles for each computation's output. + ExecutionProfile profile; + TF_ASSIGN_OR_RETURN( + std::vector outputs, + ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, + execute_backend_.get(), device_handles, + computation_names, &profile)); + for (const GlobalDataHandle& output : outputs) { + ExecuteResponse response; + *response.mutable_output() = output; + *response.mutable_profile() = profile; + *result->add_responses() = response; + } + + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; + return tensorflow::Status::OK(); +} + tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); @@ -854,6 +1011,47 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + ExecuteGraphParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::PickParallelResponse( + const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { + // The "result device" selection is a bit hacky, but better than assuming it + // is device 0. We have b/76035356 for restructuring the client API to clean + // up the current asymmetries and support more functionalities. + for (int64 i = 0; i < parallel_result.responses_size(); ++i) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica( + parallel_result.responses(i).output(), 0)); + const Shape& shape = buffer->on_host_shape(); + if (!ShapeUtil::IsEmptyTuple(shape)) { + *result = parallel_result.responses(i); + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return Status::OK(); + } + } + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + VLOG(1) << "Defaulting to device 0 result"; + return Status::OK(); +} + tensorflow::Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); @@ -870,13 +1068,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, // If we received multiple device handles, we must partition the module. if (arg->execution_options().device_handles_size() > 1) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - return Status::OK(); + return ExecuteOneToN(arg, result); } TF_ASSIGN_OR_RETURN( @@ -894,7 +1086,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), *user_computation)); + arg->execution_options(), user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -935,9 +1127,72 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return tensorflow::Status::OK(); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* /*arg*/, - ExecuteResponse* /*result*/) { - return Unimplemented("execute-graph is not yet implemented"); +StatusOr> Service::BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf( + "BuildExecutable on service %p with serialized module proto: %s", this, + module_proto.name().c_str()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(module_proto, *module_config)); + + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); + + return std::move(executable); +} + +tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute-graph request"; + + if (!arg->has_computation()) { + return InvalidArgument("computations may not be empty"); + } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("programe shape may not be empty"); + } + + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + return ExecuteOneToN(arg, result); + } + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(arg->computation().program_shape(), + replicated_arguments.front(), + arg->execution_options())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + BuildExecutable(arg->computation(), std::move(module_config), + execute_backend_.get(), + execute_backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + "result of " + arg->computation().name(), result->mutable_profile())); + + VLOG(1) << "successfully completed 'execute-graph' request"; + return tensorflow::Status::OK(); } tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, @@ -967,7 +1222,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), *user_computation)); + arg->execution_options(), user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1268,7 +1523,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(program_shape, {}, execution_options, - *user_computation)); + user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1360,6 +1615,29 @@ tensorflow::Status Service::GetComputationStats( return tensorflow::Status::OK(); } +tensorflow::Status Service::GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { + HloModuleConfig config; + config.set_debug_options(arg->debug_options()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); + + hlo_graph_dumper::MaybeDumpHloModule(*module, + "computation statistics subject"); + + // Run HLO analysis to get the computation statistics. + HloCostAnalysis analysis( + execute_backend_->compiler()->ShapeSizeBytesFunction()); + + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); + + ComputationStats stats; + stats.set_flop_count(analysis.flop_count()); + stats.set_transcendental_count(analysis.transcendental_count()); + *result->mutable_stats() = stats; + return tensorflow::Status::OK(); +} + template tensorflow::Status Service::AddInstruction( const RequestT* arg, ResponseT* result, diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 96352d9096e6aeeb33f84c7b6fc42c28820e5e84..e09d58bbe7691b4854538ca5a99bd4c0b8d53c3b 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -115,6 +115,8 @@ class Service : public ServiceInterface { // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) override; @@ -124,6 +126,15 @@ class Service : public ServiceInterface { tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) override; + // Executes one or more computations in parallel with the provided global data + // passed as immutable arguments. Returns global data output for each + // computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; + // Requests one or more device handles from the target. // // When N device handles are requested and the number of replicas is R, at @@ -222,6 +233,13 @@ class Service : public ServiceInterface { const ComputationStatsRequest* arg, ComputationStatsResponse* result) override; + // Retrieves the statistics of a computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; + // Snapshots the current state of a computation handle into a serializable // protocol buffer form, so it can be loaded via // LoadComputationSnapshot. @@ -258,7 +276,21 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); + + // Picks a parallel response and fills the result. + Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, + ExecuteResponse* result); + + // Prepare the executors for executing parallel. + StatusOr> GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const; + + // Prepare the arguments for executing parallel. + StatusOr>> GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice arguments); protected: friend class LocalExecutable; @@ -286,7 +318,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); // Builds an Executable for the given parameters. // @@ -299,6 +331,15 @@ class Service : public ServiceInterface { perftools::gputools::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); + // Builds an Executable for the given HLO module proto. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); + // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. StatusOr>> BuildExecutables( @@ -307,6 +348,12 @@ class Service : public ServiceInterface { Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); + StatusOr>> BuildExecutables( + const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and @@ -346,6 +393,14 @@ class Service : public ServiceInterface { const std::function(UserComputation*)>& adder); + // Executes a single computation which has more than one target device. + // The N devices are expected to all return an empty tuple, but one, which + // will be the result of this computation. + tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result); + tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result); + // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8c8bd6d73ad41db7d609ac91c7bdfc4703f364e1..77e12d36024dae56003ad4e59b54f9934dfc2c58 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -304,12 +304,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { + return InferUnaryOpShape(opcode, operand->shape()); +} + +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. if (opcode == HloOpcode::kCopy) { - return operand->shape(); + return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape()); + return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); } /* static */ StatusOr ShapeInference::InferUnaryOpShape( @@ -1033,8 +1038,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(), - rhs->shape(), ehs->shape()); + return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); +} + +/* static */ StatusOr ShapeInference::InferTernaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { + return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); } /* static */ StatusOr ShapeInference::InferTernaryOpShape( @@ -1061,6 +1070,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } + return InferVariadicOpShape(opcode, operand_shapes); +} + +/* static */ StatusOr ShapeInference::InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operand_shapes) { return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), operand_shapes); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 085fdac60c6de161c457dff672175e82f4f4da51..9da2c99b4177f08ece8daabaf2922ddd7e947a1b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -48,6 +48,8 @@ class ShapeInference { // given input shape. static StatusOr InferUnaryOpShape(UnaryOperation operation, const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, const HloInstruction* operand); @@ -68,6 +70,9 @@ class ShapeInference { static StatusOr InferTernaryOpShape(TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs); + static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, + const Shape& rhs, + const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, @@ -78,6 +83,9 @@ class ShapeInference { static StatusOr InferVariadicOpShape( VariadicOperation operation, tensorflow::gtl::ArraySlice operand_shapes); + static StatusOr InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operands); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 0dca30a804005c6f536aca5b54af24eb08d4560b..532f7fd5bfc1dffa86638a6bc51832beebd74e1d 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1284,8 +1284,8 @@ StatusOr UserComputation::AddCustomCallInstruction( TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); } - if (tensorflow::StringPiece(custom_call_request.call_target_name()) - .starts_with("$")) { + if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), + "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", @@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit( HloInstruction* operand = lookup_instruction(trace_request.operand()); hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); break; } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index d3d55634c97bbdf3f81321d8089bb808c411340b..3d3e1d60f294c3a2574513c1c2f071805a341ad1 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. -// - A while loops with static trip count of 1 is replaced by its body (sans +// - A while loop with static trip count of 1 is replaced by its body (sans // loop). // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index f1fea6d7634f2060abe18c0fdd51a3391dcb5ae3..619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -68,7 +68,7 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { hlo_string_template, "{{LOOP_BOUND}}", tensorflow::strings::StrCat(42 + num_iters), /*replace_all=*/true); - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); } void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( @@ -107,7 +107,7 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( hlo_string_template, "{{LOOP_BOUND}}", tensorflow::strings::StrCat(42 + num_iters), /*replace_all=*/true); - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { @@ -235,7 +235,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -267,7 +267,7 @@ TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -296,7 +296,7 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -319,7 +319,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -347,7 +347,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -389,7 +389,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); HloModule* the_module = &module(); EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -439,7 +439,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -472,7 +472,7 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -504,7 +504,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string.c_str()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 7441a7ad395bf185cd31de7d4b57beae66cc3063..bd0794184328b7926543c4275b3b915f51e7b812 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -142,23 +142,23 @@ WhileUtil::MakeInstructionsLiveIn( static StatusOr> MakeCountedLoopConditionComputation(const Shape& loop_state_shape, - int64 trip_count) { + int32 trip_count) { Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); - Shape scalar_s64 = ShapeUtil::MakeShape(S64, {}); TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, CreateComputationWithSignature( {&loop_state_shape}, scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); - TF_ASSIGN_OR_RETURN(HloInstruction * counter, - CreateGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN( HloInstruction * compare, - CreateBinaryHlo(HloOpcode::kLt, counter, trip_count_constant)); + MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); return std::move(cond_computation); } @@ -171,18 +171,17 @@ static StatusOr> MakeCountedLoopBodyComputation( CreateComputationWithSignature( {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); - + HloInstruction::CreateConstant(Literal::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, - CreateGetTupleElementHlo(param, 0)); + MakeGetTupleElementHlo(param, 0)); TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar, - CreateBinaryHlo(HloOpcode::kAdd, indvar, one)); + MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); std::vector loop_body_generator_args; for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) { TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element, - CreateGetTupleElementHlo(param, i)); + MakeGetTupleElementHlo(param, i)); loop_body_generator_args.push_back(tuple_element); } TF_ASSIGN_OR_RETURN(std::vector next_state, @@ -200,7 +199,7 @@ static StatusOr MakeInitTupleFromInitValues( std::vector init_values_with_indvar; init_values_with_indvar.reserve(init_values.size() + 1); HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); init_values_with_indvar.push_back(zero); c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( @@ -210,16 +209,18 @@ static StatusOr MakeInitTupleFromInitValues( static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); - loop_state_shape_components.push_back(ShapeUtil::MakeShape(S64, {})); + loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); c_transform(init_values, std::back_inserter(loop_state_shape_components), [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } /*static*/ StatusOr WhileUtil::MakeCountedLoop( - HloComputation* computation, int64 trip_count, + HloComputation* computation, int32 trip_count, const WhileUtil::LoopStateTy& init_values, const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + CHECK_GE(trip_count, 0); + Shape loop_state_shape = MakeLoopStateShape(init_values); TF_ASSIGN_OR_RETURN( std::unique_ptr cond, @@ -238,7 +239,7 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector result; for (int64 i = 0, e = init_values.size(); i < e; i++) { TF_ASSIGN_OR_RETURN(HloInstruction * user_state, - CreateGetTupleElementHlo(while_instr, i + 1)); + MakeGetTupleElementHlo(while_instr, i + 1)); result.push_back(user_state); } return result; diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 80f7e16e64f4d1b1faa73f4fb9b4dd6443bf488b..1688d4674269c36c5b356f262dbd5d958572e101 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -71,7 +71,7 @@ class WhileUtil { // return loop_state; // } static StatusOr MakeCountedLoop( - HloComputation* computation, int64 trip_count, + HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, const LoopBodyGeneratorTy& loop_body_generator); }; diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 063e312df66ce9cba0fa9f49c2fc6026ba6b74aa..8763e588c484011ba2ccbc7cad8f29817347a605 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index d8235113dd800f7bab5ceb70272a598b9dcb1fbe..32aae64973dbd7ac2f8d403d8fbd155d432642f9 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -60,6 +60,10 @@ class ServiceInterface { virtual tensorflow::Status ExecuteParallel( const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual tensorflow::Status ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; + virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) = 0; @@ -72,6 +76,10 @@ class ServiceInterface { virtual tensorflow::Status GetComputationStats( const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual tensorflow::Status GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) = 0; + virtual tensorflow::Status GetComputationShape( const GetComputationShapeRequest* arg, GetComputationShapeResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 4f604e6f7cb18c1aaf844967d54e3b0e07e54b34..6825d2476587d037aace043230168f78f4e46344 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -502,11 +502,11 @@ namespace { StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { tensorflow::str_util::RemoveLeadingWhitespace(s); - if (s->Consume("(")) { // Tuple. + if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (s->Consume(")")) { + if (tensorflow::str_util::ConsumePrefix(s, ")")) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", @@ -515,7 +515,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !s->Consume(","); + must_end = !tensorflow::str_util::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 025ac129d7040a007493cbb222d07c6cf323567f..6f58c20f34e30324ca36dbc7fa78ebb82a4b435d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -190,6 +190,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -346,10 +347,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -386,6 +387,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -596,6 +598,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -676,7 +679,9 @@ xla_test( name = "gather_operation_test", srcs = ["gather_operation_test.cc"], deps = [ + ":client_library_test_base", ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -932,8 +937,8 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -972,9 +977,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1006,7 +1010,10 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -1369,6 +1376,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1435,9 +1443,9 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1557,6 +1565,8 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1803,9 +1813,8 @@ tf_cc_test( deps = [ ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1952,17 +1961,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 6e21dda25d8e5151b31b8c2328253260595a94c4..03c91745b978f80801e0da5ac44d31959659b20c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,28 +51,28 @@ class ArrayElementwiseOpTestParamCount public ::testing::WithParamInterface {}; XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1, 324, std::numeric_limits::min(), std::numeric_limits::max()}); - auto result = builder.Neg(a); + builder.Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -83,18 +84,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, @@ -102,7 +103,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({ -1, 1, @@ -112,7 +113,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { static_cast(0x8000000000000000LL), static_cast(0x8000000000000001LL), }); - auto result = builder.Neg(a); + builder.Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, @@ -129,9 +130,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -140,64 +141,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.IsFinite(builder.ConstantR0(NAN)); + XlaBuilder builder(TestName()); + builder.IsFinite(builder.ConstantR0(NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto result_non_canonical = - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - auto result_inf = builder.IsFinite(builder.ConstantR0(inf)); + builder.IsFinite(builder.ConstantR0(inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_neg_inf = builder.IsFinite(builder.ConstantR0(-inf)); + builder.IsFinite(builder.ConstantR0(-inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_zero = builder.IsFinite(builder.ConstantR0(0.0f)); + builder.IsFinite(builder.ConstantR0(0.0f)); ComputeAndCompareR0(&builder, true, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); auto a = builder.ConstantR1( {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); auto b = builder.ConstantR1( {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -205,10 +205,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -244,7 +244,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto add = b.Add(lhs_param, rhs_param); + b.Add(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -295,7 +295,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector a_values; std::vector b_values; for (int i = 0; i < count; ++i) { @@ -334,49 +334,49 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); auto b = builder.ConstantR1({-1, 2, 1, -1}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -384,29 +384,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -436,9 +436,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -451,8 +451,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -461,9 +461,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -476,8 +476,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -507,9 +507,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -521,8 +521,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -531,9 +531,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -545,8 +545,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -556,33 +556,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); auto b = builder.ConstantR1( {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); auto b = builder.ConstantR1( {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -590,21 +590,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); auto b = builder.ConstantR1( {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -612,20 +612,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -648,19 +648,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -679,21 +679,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -701,264 +701,264 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, -8}); auto b = builder.ConstantR1({5, -7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {3, 8}}); auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, 8}); auto b = builder.ConstantR1({5, -7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -1}, {8, 8}}); auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {8, 8}}); auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true}, {true, false}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 4294967295}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x12345678), static_cast(0xF0001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); - auto out = builder.ShiftLeft(a, b); + builder.ShiftLeft(a, b); ComputeAndCompareR1(&builder, {static_cast(0x23456780), 0x00100000, 0x4, @@ -967,12 +967,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); - auto out = builder.ShiftRightArithmetic(a, b); + builder.ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, @@ -982,45 +982,45 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77, 1, -3, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); - auto out = builder.ShiftRightLogical(a, b); + builder.ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); - auto out = builder.ShiftLeft(a, b); + builder.ShiftLeft(a, b); ComputeAndCompareR1( &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); - auto out = builder.ShiftRightArithmetic(a, b); + builder.ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); - auto out = builder.ShiftRightLogical(a, b); + builder.ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -1028,59 +1028,59 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -1088,10 +1088,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1099,17 +1099,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1120,16 +1120,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1138,7 +1138,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1149,7 +1149,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1158,10 +1158,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1169,10 +1169,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1181,10 +1181,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1193,10 +1193,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1206,10 +1206,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1218,10 +1218,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1230,10 +1230,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1242,10 +1242,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1253,10 +1253,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1264,10 +1264,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1276,10 +1276,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1287,10 +1287,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1299,12 +1299,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1312,20 +1312,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1599,14 +1599,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector values; values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } auto x = builder.ConstantR1(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1618,7 +1618,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 2, 2); std::vector values_vector; @@ -1632,77 +1632,77 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - auto maximum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - auto maximum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, error_spec_); @@ -1711,7 +1711,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1726,7 +1726,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1740,7 +1740,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Max(x, y); @@ -1751,7 +1751,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Min(x, y); @@ -1761,7 +1761,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( @@ -1774,7 +1774,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR1({}); builder.Max(u, v); @@ -1784,7 +1784,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); @@ -1794,7 +1794,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); @@ -1805,7 +1805,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({}); auto m = builder.ConstantR2({{}, {}}); builder.Max(v, m, /*broadcast_dimensions=*/{1}); @@ -1815,7 +1815,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1826,7 +1826,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d(2, 0, 3); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1837,7 +1837,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); @@ -1848,7 +1848,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{}, {}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); builder.Min(m, v, /*broadcast_dimensions=*/{0}); @@ -1858,7 +1858,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); auto array4d = builder.ConstantR4FromArray4D( @@ -1873,7 +1873,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); @@ -1885,7 +1885,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Min(x, y); @@ -1895,7 +1895,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Max(x, y); @@ -1905,110 +1905,107 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR0(0.0f); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto maximum = builder.ConstantR0(5.0f); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2022,7 +2019,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -2030,7 +2027,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); @@ -2044,7 +2041,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -2052,7 +2049,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2061,35 +2058,35 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - auto add = builder.Add(a, p); + builder.Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Cos(a); + builder.Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Sin(a); + builder.Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - auto atan = builder.Atan2(a, b); + builder.Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2098,9 +2095,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - auto result = builder.Tanh(a); + builder.Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2110,7 +2107,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, @@ -2149,7 +2146,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. @@ -2185,7 +2182,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, @@ -2225,14 +2222,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // / / // b -----/ / // c---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(a, b); - auto add2 = builder.Add(add, c); + builder.Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2243,14 +2240,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // / / // c -----/ / // a---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(b, c); - auto add2 = builder.Add(a, add); + builder.Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2260,14 +2257,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // a ----- (neg) ----- (add) // / // b ----- (neg) ----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto neg_a = builder.Neg(a); auto neg_b = builder.Neg(b); - auto result = builder.Add(neg_a, neg_b); + builder.Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2281,7 +2278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // c ------ (add) ------------/ // / // d -----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); @@ -2290,19 +2287,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { auto add_ab = builder.Add(a, b); auto add_cd = builder.Add(c, d); - auto add_all = builder.Add(add_ab, add_cd); + builder.Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2311,11 +2308,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(scalar, a); + builder.Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2323,11 +2320,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(a, scalar); + builder.Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2336,14 +2333,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); // clang-format off auto m = builder.ConstantR2({ {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2369,10 +2366,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + builder.Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2383,10 +2380,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + builder.Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2397,10 +2394,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + builder.Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2411,10 +2408,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + builder.Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2425,10 +2422,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + builder.Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2440,24 +2437,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + builder.Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2465,14 +2462,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2483,13 +2480,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // effectively creates an "outer product" operation. // This is taken from the Numpy docs example at: // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2500,10 +2497,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2511,17 +2508,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + builder.Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { // Binary add of two R3s together - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2529,7 +2526,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto add = builder.Add(a, b); + builder.Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2540,7 +2537,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2553,7 +2550,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + builder.Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2564,7 +2561,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2577,7 +2574,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + builder.Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2595,7 +2592,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2} // for broadcasting. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2610,7 +2607,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2627,7 +2624,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { // Comparison between two 3D arrays of compatible shapes: // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2635,7 +2632,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto compare = builder.Gt(a, b); + builder.Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2651,7 +2648,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { } XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> operand_b_4d(new Array4D(2, 3, 4, 5)); @@ -2672,13 +2669,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); @@ -2700,7 +2697,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR1(operand_b_1d); - auto add = builder.Add(a, b, {1}); + builder.Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2715,7 +2712,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::vector r1(d1); std::iota(r1.begin(), r1.end(), 1.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); @@ -2736,11 +2733,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { // Show that we can't add two opaques. XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); auto x = builder.Parameter(0, shape, "x"); - auto concatenated = builder.Add(x, x); - StatusOr computation_status = builder.Build(); + builder.Add(x, x); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( @@ -2748,12 +2745,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2761,14 +2758,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { } XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); - StatusOr computation_status = builder.Build(); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().error_message(), ::testing::ContainsRegex("must.*be the identity")); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 3f6fd7c65d3360a622dbf754833009fb20410535..ec3b46acfec0ee0ff514a862ce5b1ca74279efa8 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,11 +29,11 @@ namespace { class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { - ComputationBuilder builder(client_, "ax_10"); + XlaBuilder builder("ax_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Mul(alpha, x); + builder.Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -46,7 +47,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { auto x = builder.ConstantR1({}); auto y = builder.ConstantR1({}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -60,7 +61,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto y = builder.ConstantR1( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 28ab9654997728fbafd6610af840e721e72cce5a..af8af99c791e2a40cfcfa2291b786b33e5652267 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -69,6 +69,17 @@ class BatchNormalizationTest CHECK_EQ(kY, input_array_.width()); } + ComputationDataHandle CheckShape(ComputationBuilder* b, + const ComputationDataHandle& operand, + const Shape& expected_shape) const { + std::unique_ptr actual_shape = + b->GetShape(operand).ConsumeValueOrDie(); + CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) + << "want " << ShapeUtil::HumanString(expected_shape) << " got " + << ShapeUtil::HumanString(*actual_shape); + return operand; + } + static constexpr int64 kSamples = 3; static constexpr int64 kX = 1; static constexpr int64 kY = 1; @@ -164,14 +175,15 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { ComputationBuilder builder(client_, "batch_normalize_per_spec"); auto input_activations = - builder.CheckShape(builder.ConstantLiteral(input_literal_), - ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + CheckShape(&builder, builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); auto gamma = builder.ConstantR1({1.0, 1.0}); auto beta = builder.ConstantR1({0.0, 0.0}); Computation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); - auto sum = builder.CheckShape( + auto sum = CheckShape( + &builder, builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); @@ -187,14 +199,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.CheckShape( + auto sum_of_squares = CheckShape( + &builder, builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = builder.Div(sum_of_squares, count); auto standard_deviation = builder.SqrtF32(variance); - auto standard_deviation_above_epsilon = builder.CheckShape( - builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto standard_deviation_above_epsilon = + CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = builder.Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); auto normalization_factors = builder.ReciprocalF32(gt_eps); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 610302ac1256a57db6ed6e18016a4136973e3891..eac2eb286c3f7a1cd33aed03686e99ef753b773a 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -137,7 +137,8 @@ def xla_test(name, backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] this_backend_tags += ["requires-gpu-sm35"] elif backend in plugins: - backend_deps = plugins[backend]["deps"] + backend_deps = [] + backend_deps += plugins[backend]["deps"] this_backend_copts += plugins[backend]["copts"] this_backend_tags += plugins[backend]["tags"] this_backend_args += plugins[backend]["args"] diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a677986cd926cc0054d8f36abc98ccac33dc043d..17c6a83c1a3153f78da7f5f6c9b76542bc564203 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -95,6 +95,20 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + return client_->ExecuteAndTransfer(computation, arguments, + &execution_options); +} + +template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, @@ -104,6 +118,15 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } +template <> +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -116,14 +139,31 @@ std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); } +string ClientLibraryTestBase::ExecuteToString( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status().ToString(); + } + auto computation = computation_status.ConsumeValueOrDie(); + + auto result = + client_->ExecuteAndTransfer(computation, arguments, &execution_options_); + if (!result.ok()) { + return result.status().ToString(); + } else { + return result.ValueOrDie()->ToString(); + } +} + string ClientLibraryTestBase::ExecuteToString( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { - StatusOr computation_status = builder->Build(); + auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); } - Computation computation = computation_status.ConsumeValueOrDie(); + auto computation = computation_status.ConsumeValueOrDie(); auto result = client_->ExecuteAndTransfer(computation, arguments, &execution_options_); @@ -142,16 +182,18 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -249,8 +291,28 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/, + const Shape* /*output_with_layout*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -307,8 +369,9 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( return tensorflow::Status::OK(); } +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -378,8 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( EXPECT_EQ(expected, actual->GetR1U8AsString()); } +template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -390,8 +454,9 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectEqual(expected, *actual); } +template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -522,33 +587,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - ComputationDataHandle ClientLibraryTestBase::AddParam( const Literal& argument, ComputationBuilder* builder) { ComputationDataHandle data_handle; @@ -563,4 +601,46 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, + XlaBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ba0319990bc04196386e6812b0a03671676698ec..52f31b06698a424929df0ea1425ca66b5ac96a18 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -94,15 +95,25 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); + + // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once + // the migration to XlaBuilder is complete. + + template StatusOr> ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, + BuilderT* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); + // Convenience OrDie variants of above methods. std::unique_ptr ExecuteOrDie( ComputationBuilder* builder, @@ -113,29 +124,31 @@ class ClientLibraryTestBase : public ::testing::Test { // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. + string ExecuteToString(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); string ExecuteToString(ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder for details). For each rank, two forms are provided: one - // for floating point types with an ErrorSpec parameter, and one for integral - // types without the ErrorSpec parameter. - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // provided: one for floating point types with an ErrorSpec parameter, and one + // for integral types without the ErrorSpec parameter. + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); @@ -146,55 +159,53 @@ class ClientLibraryTestBase : public ::testing::Test { const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -206,11 +217,13 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. + template void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); + template void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result @@ -266,17 +279,19 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + BuilderT* builder, HandleT* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters @@ -297,6 +312,7 @@ class ClientLibraryTestBase : public ::testing::Test { // will be converted to BF16s. ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, ComputationBuilder* builder); + XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be @@ -307,6 +323,12 @@ class ClientLibraryTestBase : public ::testing::Test { return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); } + template + XlaOp CreateConstantFromArray(const Array& array, + XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + // Same as CreateConstantFromArray, but for scalars. template ComputationDataHandle CreateConstantFromScalar(NativeT value, @@ -315,6 +337,12 @@ class ClientLibraryTestBase : public ::testing::Test { builder); } + template + XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // @@ -323,10 +351,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template - std::unique_ptr CreateR0Parameter( - NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + template + std::unique_ptr CreateR0Parameter(NativeT value, + int64 parameter_number, + const string& name, + BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -336,11 +366,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -351,11 +380,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -366,11 +394,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -399,6 +426,18 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output); + tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output, + const Shape* output_with_layout = nullptr); + // Executes the computation and calculates the expected reference value using // the HloEvaluator. Returns two literal in the order of (expected, actual). StatusOr, std::unique_ptr>> @@ -414,9 +453,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -424,9 +463,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -440,9 +479,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -450,9 +489,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -466,9 +505,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -476,9 +515,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -492,9 +531,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -502,9 +541,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -518,9 +557,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -528,9 +567,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -544,10 +583,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { + BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -558,11 +597,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -573,11 +611,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -588,11 +625,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -628,6 +664,37 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + BuilderT* builder, + HandleT* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 045148cdd11da94ae4789a753efca95c6aaa1f27..32e2f2c0848407ec46a5ac52e2668ef27b92c426 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -109,14 +111,14 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XLA_TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { - Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr const_arg, client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Add(b.Parameter(0, shape, "param_0"), b.ConstantR2({{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); @@ -124,14 +126,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::ComputationInstance( + computation_instances.push_back(Client::XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ec2c580670cfac14ba42e8c9a836c86551af4b89..e5a03b49ad259a64b9cbbc88c31d8c6558289d1b 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -167,8 +168,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } @@ -183,8 +184,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index fb0e9c724a69b61801e6e0c2d07ef75b63a00465..a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -38,9 +38,9 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { - ComputationBuilder builder(client_, TestName()); - auto concatenated = builder.ConcatInDim({}, 0); - StatusOr computation_status = builder.Build(); + XlaBuilder builder(TestName()); + builder.ConcatInDim({}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("Concatenate expects at least one argument")); @@ -48,18 +48,18 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -68,51 +68,51 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(42.0); auto b = builder.ConstantR0(64.0); - auto concatenated = builder.ConcatInDim({a, b}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({a, b}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -129,20 +129,20 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { expected[253 + i] = rhs[i] = 253 + i + 1; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(lhs); auto b = builder.ConstantR1(rhs); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto concatenated = builder.ConcatInDim({a, b}, dim); + builder.ConcatInDim({a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -150,26 +150,27 @@ XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ - {0}, {64}, + {0}, + {64}, }); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 64}, @@ -178,22 +179,22 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { } XLA_TEST_F(ConcatTest, Concat2x0With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat2x3With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -203,22 +204,22 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { } XLA_TEST_F(ConcatTest, Concat3x2With0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat3x2With5x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ {0, 1}, @@ -234,16 +235,16 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { } XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -258,27 +259,29 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { }); auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); Array3D expected({ - {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, + {{0, 1, 6}}, + {{2, 3, 7}}, + {{4, 5, 8}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b, c}, 0); + builder.ConcatInDim({a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -300,35 +303,35 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); auto c = builder.ConstantR3FromArray3D(c_array); - auto concatenated = builder.ConcatInDim({a, b, c}, 2); + builder.ConcatInDim({a, b, c}, 2); Array3D expected({ - {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, + {{0, 1, 2, 3}}, + {{4, 5, 6, 7}}, + {{8, 9, 10, 11}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = (a concat b) concat c - auto concatenated = - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = a concat (b concat c) - auto concatenated = - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -342,7 +345,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 0); @@ -363,7 +366,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -388,7 +391,7 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -404,13 +407,13 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { // Show that we can't concatenate with an opaques. XLA_TEST_F(ConcatTest, CannotConcatOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); auto x = builder.Parameter(0, r1f32, "x"); auto y = builder.Parameter(1, opaque_shape, "y"); - auto concatenated = builder.ConcatInDim({x, y}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), @@ -418,23 +421,23 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p0 = builder.ConstantR1({true}); auto p1 = builder.ConstantR1({false}); auto p2 = builder.ConstantR1({true}); - auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); + builder.ConcatInDim({p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR1({1}); auto a1 = builder.ConstantR1({2, 3}); auto a2 = builder.ConstantR1({4, 5, 6}); auto a3 = builder.ConstantR1({7, 8, 9, 10}); - auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); + builder.ConcatInDim({a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -442,7 +445,7 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { } XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D arr0(9, 17, 1); arr0.Fill(1); @@ -462,14 +465,14 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { } } - ComputationDataHandle h0; + XlaOp h0; auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - auto concatenated = builder.ConcatInDim({h0, h1}, 2); + builder.ConcatInDim({h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -495,7 +498,7 @@ TEST_P(ConcatR2BinaryTest, DoIt) { Array2D rhs(spec.rhs_dim0, spec.rhs_dim1); rhs.FillUnique(1000); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR2FromArray2D(lhs); auto a1 = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a0, a1}, spec.concat_dimension); @@ -521,7 +524,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, f32_scalar, "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto mul = builder.Mul(x, y); @@ -545,7 +548,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "z"); @@ -573,7 +576,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "y"); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index bc821674820fb128823786d7149037fc59b22ab6..b917dee77b5400db8f2c0a6a86258fee64723d71 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -571,5 +571,56 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { "only parameter of true_computation")); } +XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); + Computation swapper; + { + ComputationBuilder builder(client_, TestName() + ".swapper"); + auto param0 = builder.Parameter(0, tuple_shape, "sp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({y, x}); + swapper = builder.Build().ConsumeValueOrDie(); + } + Computation forwarder; + { + ComputationBuilder builder(client_, TestName() + ".forwarder"); + auto param0 = builder.Parameter(0, tuple_shape, "fp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({x, y}); + forwarder = builder.Build().ConsumeValueOrDie(); + } + Computation main; + { + ComputationBuilder builder(client_, TestName() + ".main"); + auto param0 = builder.Parameter(0, tuple_shape, "mp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + auto lt_pred = builder.Lt(x, y); + auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = builder.Ge(x, y); + builder.Conditional(ge_pred, res, swapper, res, forwarder); + main = builder.Build().ConsumeValueOrDie(); + } + + auto test_swap = [&](float a, float b) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(a); + auto y = builder.ConstantR0(b); + auto tuple_operand = builder.Tuple({x, y}); + builder.Call(main, {tuple_operand}); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(a).get(), + Literal::CreateR0(b).get()}), + {}, error_spec_); + }; + + test_swap(3.11f, 9.4f); + test_swap(11.24f, 5.55f); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 59d6d7a4153be1b76ed8195a12a90cb103baa422..0842a8918bcfec037ab0f9aa24014c7d8296cdf8 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -177,6 +178,24 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0.0f, 1.0f, 16777216.0f, + 16777218.0f, 2147483647.0f, 4294967040.0f}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, U32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { ComputationBuilder builder(client_, TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; @@ -211,6 +230,43 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + // Test cases from compiler_rt library. + std::vector arg{0.0f, + 0.5f, + 0.99f, + 1.0f, + 1.5f, + 1.99f, + 2.0f, + 2.01f, + 2147483648.f, + -0.5f, + -0.99f, + -1.0f, + -1.5f, + -1.99f, + -2.0f, + -2.01f, + 0x1.FFFFFEp+62F, + 0x1.FFFFFCp+62F, + -0x1.FFFFFEp+62F, + -0x1.FFFFFCp+62F}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({32, 64}); @@ -366,5 +422,44 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); } + +XLA_TEST_F(ConvertTest, ConvertC64ToC64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42.0f, 64.0f}}; + builder.ConvertElementType(builder.ConstantR1(x), C64); + ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConvertTest, ConvertS64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{-42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), S64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), U64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector unsigned_x = {{42, UINT64_MAX}}; + builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + std::vector signed_x = {{42, -1}}; + ComputeAndCompareR1(&builder, signed_x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertS64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector signed_x = {{42, -1, INT64_MIN}}; + builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + std::vector unsigned_x = { + {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; + ComputeAndCompareR1(&builder, unsigned_x, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 09b1dd283e4d026a2f0007240d88cd9ac38acb19..7b994a4c172cafee53ede9bfd4f30b0e0c9888d5 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -54,6 +54,25 @@ using TypesF16F32F64CF64 = #error "Situation not handled yet" #endif +// Check that we can safely pass an input tuple's elements to a dot operation. +TEST_F(DotOperationTest, DotOfInputTupleElem) { + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle param; + auto param_data = CreateParameterAndTransferLiteral( + 0, + *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), + Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + "arg0", &builder, ¶m); + auto lhs = builder.GetTupleElement(param, 0); + auto rhs = builder.GetTupleElement(param, 1); + builder.Dot(lhs, rhs); + + ComputeAndCompareLiteral(&builder, + *Literal::CreateR2({{19, 22}, {43, 50}}), + {param_data.get()}); +} + template class DotOperationTest_F16F32F64CF64 : public DotOperationTest {}; TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 4f354e6aefe70a51c09be1c0ca151af2bb9f0a2c..5f00c34002803553b9c17b4fce0abafda7369796 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3( - {{{1, 2}, {3, 4}, {5, 6}}, - {{7, 8}, {9, 10}, {11, 12}}}, - {0, 2, 1}, {2, 1, 2}, - {{{6, 5}}, {{12, 11}}}); + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, + {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); } template @@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -281,6 +278,15 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { class DynamicUpdateSliceTest : public ClientLibraryTestBase { protected: + template + void TestR0() { + // Disable algebraic simplifier, otherwise the op will be replaced by a + // constant. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); + RunR0(0, 123, {}, 123); + } + template void TestR1() { // Slice at dimension start. @@ -341,6 +347,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); } + template + void RunR0(int input_value_int, int update_value_int, + const std::vector slice_starts, int expected_value_int) { + Literal input_value = + std::move(*Literal::CreateR0(input_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_value = + std::move(*Literal::CreateR0(update_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_value = + std::move(*Literal::CreateR0(expected_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantLiteral(input_value); + auto update = builder.ConstantLiteral(update_value); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + } + template void RunR1(tensorflow::gtl::ArraySlice input_values_int, tensorflow::gtl::ArraySlice update_values_int, @@ -359,9 +394,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -390,9 +425,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -421,9 +456,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -474,13 +509,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } // Build dynamic slice computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer input parameter. - ComputationDataHandle input; + XlaOp input; std::unique_ptr input_data = CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. - ComputationDataHandle update; + XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1({index, 0, 0}); @@ -500,6 +535,11 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { TestR1(); @@ -672,7 +712,7 @@ void BM_DynamicSlice(int num_iters) { TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); - ComputationBuilder builder(client, "DynamicSlice"); + XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] auto input_literal = Literal::CreateR4( diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 6fe7737de7af349dca2931b52d62dbc03b14e0b3..b28fe0c15a89a1331698a29f70b966380bd3fcb9 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -71,8 +71,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #ifdef XLA_TEST_BACKEND_CPU // TODO(b/73141998): The vectorized Log implementation gives results outside // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64): - std::pair known_incorrect_range = {1, 8315654}; + // floats expressed as a zero extended int64). + std::pair known_incorrect_range = {1, 8388608}; #else std::pair known_incorrect_range = {0, 0}; #endif diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4e2f19ade10794fd159ff89807d6ab34630dbb43..9db68ff7a6dcbd9204fb2b3a37734a9aaed35dfd 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" @@ -31,12 +33,16 @@ class GatherOperationTest : public HloTestBase { protected: void RunTest(const string& hlo_text, Literal* operand, Literal* gather_indices) { + RunTest(hlo_text, {operand, gather_indices}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, tools::Parse(hlo_text, config)); - EXPECT_TRUE( - RunAndCompare(std::move(module), {operand, gather_indices}, nullopt)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; @@ -259,5 +265,197 @@ ENTRY main { RunTest(hlo_text, operand.get(), gather_indices.get()); } +XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, + // we should get rid of the mask and check that backends produce the same + // value for OOB indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndex) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for negative + // accesses, we should get rid of the mask and check that backends produce the + // same value for negative indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), + output_window_dims={0,1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1,3,2} +} +)"; + std::unique_ptr operand = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ScalarResult) { + const char* hlo_text = R"( +HloModule ScalarResult + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[] gather(operand, index), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1} +} +)"; + std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { + const string hlo_text = R"( +HloModule ZeroSizedResult + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + ROOT gather = s32[0,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +class GatherClientLibraryTest : public ClientLibraryTestBase {}; + +// TODO(b/30671675): Asynchronous execution on stream is not yet supported on +// GPU and CPU_PARALLEL. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { + // We create this HLO, but using the ComputationBuilder API. + // + // ENTRY main { + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // ROOT gather = s32[2,3] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0}, + // gather_dims_to_operand_dims={0}, + // index_vector_dim=1, + // window_bounds={1, 3} + // } + + ComputationBuilder builder(client_, "gather_basic"); + + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = builder.Parameter(0, operand_shape, "operand"); + auto indices = builder.Parameter(1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(1); + dim_numbers.add_elided_window_dims(0); + dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.set_index_vector_dim(1); + builder.Gather(operand, indices, dim_numbers, {1, 3}); + + std::vector expected = {}; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, + client_->TransferToServer(*Literal::CreateR2( + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr indices_arg, + client_->TransferToServer(*Literal::CreateR1({0, 2}))); + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + *execution_options.add_device_handles() = devices[0]; + TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build()); + std::vector computation_instances = { + {computation, + {operand_arg.get(), indices_arg.get()}, + execution_options, + /*execution_profile=*/nullptr}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector> result_data, + client_->ExecuteParallel(computation_instances)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + client_->Transfer(*(result_data[0]))); + LiteralTestUtil::ExpectEqual( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index eded2077fce965ab1c729c610764afa2228ca128..cf971dd61b71ad329b20b0bb7c16166126562681 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase { metadata_.set_op_name("my_sum_op"); } - void BuildAddComputation(ComputationBuilder* builder) { + void BuildAddComputation(XlaBuilder* builder) { auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder->Add(x, y); @@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase { }; TEST_F(HloMetadataTest, MetadataPropagation) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); BuildAddComputation(&builder); builder.ClearOpMetadata(); @@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { } TEST_F(HloMetadataTest, MetadataClearing) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); // Some other pretend computation here. builder.ClearOpMetadata(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5f62c44f25dd62b563bd8ce02477bd741f264182..e574644dea7c1ba144ba87fbeb7f28cc52312e26 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -115,6 +115,13 @@ StatusOr> HloTestBase::Execute( return test_runner_.Execute(std::move(module), arguments); } +StatusOr> HloTestBase::ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments, + /*run_hlo_passes=*/false); +} + std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index e375f13a44e0618f2a498325859d928c1adb830e..3e8e2360bb3a87e127920cd222803c0f7b9161f4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -98,6 +98,12 @@ class HloTestBase : public ::testing::Test { std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); + // Same as above, except the module will be executed without running any HLO + // passes on it. + StatusOr> ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 641907acf260c099a5ac885c362d92a0b6d78a42..da4cf4ae0c31bc194cd2ec9b845df36afbde69b0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -64,7 +64,8 @@ HloModule& HloVerifiedTestBase::module() { return *module_; } -void HloVerifiedTestBase::ParseAndVerifyModule(const char* hlo_text) { +void HloVerifiedTestBase::ParseAndVerifyModule( + tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); VerifyModule(); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index c0cb12bc93f56a5cb5ebdac94488369331f0cea6..e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,7 +44,7 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(const char* hlo_text); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 50d7b5074d201d2292cf90224ef4cd37efdbb8d3..d24927d22b6534b46e711cd442f19a3e5cfcebdf 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -57,6 +57,11 @@ limitations under the License. namespace xla { namespace { +using FuncGeneratorForType = Computation (*)(PrimitiveType, + ComputationBuilder*); + +using FuncGenerator = Computation (*)(ComputationBuilder*); + class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { @@ -755,53 +760,57 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { } XLA_TEST_F(ReduceTest, VectorizedReduce_Add) { - RunVectorizedReduceTest(CreateScalarAddComputation, - [](float a, float b) { return a + b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) + - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); + RunVectorizedReduceTest( + static_cast(CreateScalarAddComputation), + [](float a, float b) { return a + b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) + + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); } XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) { - RunVectorizedReduceTest(CreateScalarMultiplyComputation, - [](float a, float b) { return a * b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) * - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); + RunVectorizedReduceTest( + static_cast(CreateScalarMultiplyComputation), + [](float a, float b) { return a * b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) * + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); } XLA_TEST_F(ReduceTest, VectorizedReduce_Max) { - RunVectorizedReduceTest(CreateScalarMaxComputation, - [](float a, float b) { return std::max(a, b); }, - [](int32 a, int32 b) { return std::max(a, b); }, - [](uint32 a, uint32 b) { return std::max(a, b); }, - std::numeric_limits::min(), - std::numeric_limits::min(), - std::numeric_limits::min()); + RunVectorizedReduceTest( + static_cast(CreateScalarMaxComputation), + [](float a, float b) { return std::max(a, b); }, + [](int32 a, int32 b) { return std::max(a, b); }, + [](uint32 a, uint32 b) { return std::max(a, b); }, + std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::min()); } XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { - RunVectorizedReduceTest(CreateScalarMinComputation, - [](float a, float b) { return std::min(a, b); }, - [](int32 a, int32 b) { return std::min(a, b); }, - [](uint32 a, uint32 b) { return std::min(a, b); }, - std::numeric_limits::max(), - std::numeric_limits::max(), - std::numeric_limits::max()); + RunVectorizedReduceTest( + static_cast(CreateScalarMinComputation), + [](float a, float b) { return std::min(a, b); }, + [](int32 a, int32 b) { return std::min(a, b); }, + [](uint32 a, uint32 b) { return std::min(a, b); }, + std::numeric_limits::max(), std::numeric_limits::max(), + std::numeric_limits::max()); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType( - CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); + static_cast(CreateScalarAndComputation), + [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType( - CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); + static_cast(CreateScalarOrComputation), + [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, @@ -884,5 +893,47 @@ XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); } +// Tests reductions with different initial values. There's no test macro that +// combines TYPED_TEST and TYPED_P, so we have to do it manually. +class ReduceInitializerTest : public ReduceTest { + protected: + template + void DoTest(T initializer, int num_elems) { + ComputationBuilder builder(client_, TestName()); + Computation max_fn = CreateScalarMaxComputation( + primitive_util::NativeToPrimitiveType(), &builder); + + auto init = builder.ConstantR0(initializer); + std::vector input_arr(num_elems, std::numeric_limits::lowest()); + auto input_literal = Literal::CreateR1(input_arr); + auto input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, + max_fn, {0}); + + ComputeAndCompareR0(&builder, initializer, {input_data.get()}); + } +}; + +XLA_TEST_F(ReduceInitializerTest, U8Small) { DoTest(42, 2); } + +XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) { DoTest(42, 4096); } + +XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) { + DoTest(42, 4095); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) { + DoTest(0, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) { + DoTest(1, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { + DoTest(1234556789123, 1024); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 8b736f62f045bb913ac09add2f00d5edb0692d83..8dd24f1237136e2807cea8a261ead25f5c7adbb2 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -252,6 +252,48 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { DefaultErrorSpec()); } +// Tests the super windowing logic w.r.t handling prime number of windows in a +// major dimension with reduction. +TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { + Array4D input_array(15, 15, 4, 128); + input_array.FillRandom(2.f, 4.f); + + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + +TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { + Array4D input_array(19, 17, 8, 256); + input_array.FillWithMinorDimNum(); + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + // Tests a reduction function that is not a simple add/min/max/etc. XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); @@ -1021,6 +1063,15 @@ struct R2ReduceWindowTestData { /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, +// TODO(b/76025683): These tests fail on TPU. +#if defined(XLA_TEST_BACKEND_CPU) || defined(XLA_TEST_BACKEND_GPU) + {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, +#endif }; string R2ReduceWindowTestDataToString( @@ -1351,5 +1402,41 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } +TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[1,1]{1,0} parameter(0) + negate = f32[1,1]{1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { + const string& hlo_string = R"( +HloModule R3Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R3Window { + operand = f32[1,1,1]{2,1,0} parameter(0) + negate = f32[1,1,1]{2,1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index f7b04debd4f5c40a904e32c832b6fc384a03c33b..02272d60171c70896f44b0d6b96f176ea52e686f 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -207,9 +208,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -221,10 +222,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { // Splits a vector into a matrix. XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -241,9 +242,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -255,10 +256,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { // Transposes a 2-dimensional row vector to a column vector. XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = Literal::CreateFromArray(*simple); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -272,10 +273,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { // Transposes a 2-dimensional array. XLA_TEST_P(ReshapeTest, TransposeAsReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -291,11 +292,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. // -// Transposes a 0x4 array with ComputationBuilder::Trans. +// Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -306,10 +307,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { // Transposes a 2-dimensional array with ComputationBuilder::Trans. XLA_TEST_P(ReshapeTest, Transpose4x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -327,9 +328,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index fe36df160daacc4fdfbdb0b75f8304f91e1a4245..69fbe98bd63661322d37936c90a5fe3580efc2de 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -41,7 +41,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); @@ -54,7 +54,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); @@ -67,7 +67,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); @@ -77,7 +77,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { } XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); @@ -85,7 +85,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); @@ -93,7 +93,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); @@ -108,7 +108,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); @@ -126,7 +126,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { Array2D values(1, 4096); std::iota(values.data(), values.data() + 4096, 0.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); @@ -147,7 +147,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -159,7 +159,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); @@ -172,7 +172,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { /*strides=*/{{1, 1, 2, 1}}); auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), @@ -193,15 +193,18 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - std::vector input(spec.input_dim0); + // This can't be an std::vector, since you can't grab an ArraySlice of a + // vector. + tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR1(input); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); - std::vector expected; + // Ditto. + tensorflow::gtl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); @@ -211,6 +214,9 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +// A version of SliceR1Test used to label and disable 'large' tests +class SliceR1LargeTest : public SliceR1Test {}; + string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, @@ -230,6 +236,21 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } + + // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off @@ -237,12 +258,6 @@ INSTANTIATE_TEST_CASE_P( SliceR1TestInstantiation, SliceR1Test, ::testing::Values( -// TODO(b/69425338): This uses too much memory on GPU. -#ifndef XLA_TEST_BACKEND_GPU - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, -#endif R1Spec{10, 0, 0, 1}, R1Spec{10, 7, 7, 1}, R1Spec{10, 0, 5, 1}, @@ -278,6 +293,20 @@ INSTANTIATE_TEST_CASE_P( SliceR1TestDataToString ); +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU +INSTANTIATE_TEST_CASE_P( + SliceR1TestBigSlicesInstantiation, + SliceR1LargeTest, + ::testing::Values( + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1} + ), + SliceR1TestDataToString +); +#endif + INSTANTIATE_TEST_CASE_P( SliceStridedR1TestInstantiation, SliceR1Test, @@ -334,7 +363,7 @@ XLA_TEST_P(SliceR2Test, DoIt) { Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); @@ -424,7 +453,7 @@ class SliceR4Test : public ClientLibraryTestBase, values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = builder.Parameter(0, literal->shape(), "p0"); diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 978a669bcab720bddec5c4bcd0144810ba3c8477..be35ec6c6ee4c015755622b2dc9bb92e23af7c85 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0bc7df2a65b44a76f877b6513e6bf93b99fbc1a3..821432ef7dc7249d547a2d5f8868300388dc9d37 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -23,14 +23,14 @@ namespace xla { namespace { -template -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { +template +void PopulateWithRandomFloatingPointDataImpl(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); + std::uniform_real_distribution generator(1.0f, 1.125f); const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice indices) { @@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal, FloatT index_bias = static_cast(index_product % 113 - negative_bias) / static_cast(256.0f); - return (generator(*engine) - 1.0625) + index_bias; + return static_cast(generator(*engine) - 1.0625f) + index_bias; })); } +template +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + +template <> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> @@ -100,6 +112,9 @@ StatusOr> MakeFakeLiteralInternal( case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); break; + case F16: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; case F32: PopulateWithRandomFloatingPointData(literal.get(), engine); break; diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2029312f94a14bc81706368b9ecfc2727fd9fe4c..098be6d7aabe88d0deef600716229ddbd0bcae2f 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -20,11 +20,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +43,7 @@ class TupleTest : public ClientLibraryTestBase { // Tests a tuple-shaped constant. XLA_TEST_F(TupleTest, TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -53,13 +56,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { Literal::CreateR1(constant_vector).get(), Literal::CreateR2(constant_matrix).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests a tuple made of scalar constants. XLA_TEST_F(TupleTest, TupleScalarConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; @@ -67,13 +70,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), Literal::CreateR0(constant_scalar2).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -81,9 +84,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto result = builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); + builder.Tuple({builder.ConstantR0(constant_scalar), + builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); auto expected = Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), @@ -94,9 +97,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - auto result = builder.Tuple( + builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), @@ -106,15 +109,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.Tuple({}); + XlaBuilder builder(TestName()); + builder.Tuple({}); auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -122,23 +125,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto tuple_data = builder.Tuple( {builder.ConstantR1({}), builder.ConstantR2FromArray2D(Array2D(0, 101))}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto value = builder.ConstantR1({4.5f}); builder.GetTupleElement(value, 1); auto result_status = builder.Build(); @@ -151,7 +154,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { // Extracts both elements from a tuple with GetTupleElement and then adds them // together. XLA_TEST_F(TupleTest, AddTupleElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -163,22 +166,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) { auto matrix_element = builder.GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - auto result = builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + builder.Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } // Extracts both elements from a tuple and then puts them into a new tuple in // the opposite order. XLA_TEST_F(TupleTest, TupleGTEToTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -186,8 +189,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); + builder.Tuple({builder.GetTupleElement(tuple_data, 1), + builder.GetTupleElement(tuple_data, 0)}); auto expected = Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), Literal::CreateR1(constant_vector).get()}); @@ -195,8 +198,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle v1, v2; + XlaBuilder b(TestName()); + XlaOp v1, v2; for (bool direction : {false, true}) { std::unique_ptr v1_data = @@ -209,7 +212,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v2_gt = b.Gt(v2, v1); // true auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = Literal::MakeTuple({Literal::CreateR0(direction).get(), Literal::CreateR0(!direction).get()}); @@ -236,7 +239,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { // \ (tuple10)-- / // \ / \ / // -----(GTE 0)-- --(GTE 1)---------- - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -256,8 +259,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { auto addvectors = builder.Add(vector_from_01, vector_from_10); auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - auto result = builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + builder.Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -268,7 +271,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { // Tests a selection between tuples with "false" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -277,8 +280,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -313,7 +315,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { // Tests a selection between tuples with "true" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -322,8 +324,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(true), tuple12, tuple21); + builder.Select(builder.ConstantR0(true), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -332,7 +333,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { // Tests a selection between tuples but the final result is an element of the // tuple, not the whole tuple. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -343,7 +344,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto element = builder.GetTupleElement(select, 0); + builder.GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } @@ -367,7 +368,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { // / --(GTE 1)-- // / // (tuple 21) - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -383,8 +384,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); auto select2 = builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - auto result = builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + builder.Add(builder.GetTupleElement(select2, 0), + builder.GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } @@ -393,7 +394,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { // Similar to SelectBetweenTuples, but the constants are shared between the // input tuples. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -402,19 +403,18 @@ XLA_TEST_F(TupleTest, auto tuple12 = builder.Tuple({c1, c2}); auto tuple21 = builder.Tuple({c2, c1}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple( {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - auto outer_tuple = - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); auto expected_v1 = Literal::CreateR1({1.0, 2.0}); auto expected_s = Literal::CreateR0(42.0); @@ -428,7 +428,7 @@ XLA_TEST_F(TupleTest, NestedTuples) { } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {3}); Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); @@ -459,7 +459,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { } XLA_TEST_F(TupleTest, ComplexTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape c64r0 = ShapeUtil::MakeShape(C64, {}); Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); @@ -514,5 +514,33 @@ XLA_TEST_F(TupleTest, ComplexTuples) { error_spec_); } +class TupleHloTest : public HloTestBase {}; + +// Disabled on CPU parallel because that's broken and will be removed soon. +// Disabled on the interpreter because bitcast doesn't exist on the interpreter. +TEST_F(TupleHloTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_CPU_PARALLEL(BitcastAfterGTE))) { + const char* testcase = R"( + HloModule m + + ENTRY test { + name.1 = (f32[3]{0}) parameter(0) + get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0 + bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1) + copy = f32[1,3]{1,0} copy(bitcast) + ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + TF_ASSERT_OK_AND_ASSIGN(auto result, + ExecuteNoHloPasses(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 33d457c70bac84c2da10e3cf9302c2c952cf1bc2..89ce2ce797f979b8668fbdb172a4a3abc5922b9f 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Reduce(builder.ConstantR1(2, 1), builder.ConstantR0(0), CreateScalarAddComputation(S32, &builder), {0}); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Or(prev, builder.ConstantR0(true)); + builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true)); - auto result = builder.While(condition, body, init); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1({}); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1({}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.0001)); } @@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); builder.Tuple({result}); // Individual elements with increase by 1/8 each time through the loop, so @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(N); auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); @@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the body. // Add 1 to the iteration variable permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); @@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); std::vector expected = {6.f, 6.f, 6.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple({builder.ConstantR0(0), builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_predicate = Literal::CreateR0(true); @@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the body. // Add 1 to the iteration variable and set the other tuple element to a // constant. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); - auto result = - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), + builder.ConstantR0(7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR0(7)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR0(7); @@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - Computation body2; + XlaComputation body2; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); + {}); builder.Gt(builder.ConstantR0(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1({1}), @@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { builder.ConstantR0(100), ShapeUtil::MakeShape(S32, {5}))}, 0); - auto result = builder.Add(inc, prev); + builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - auto result = builder.While(build_condition(count), body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(build_condition(count), body, init); return builder.Build(); }; @@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - Computation inner_condition; + XlaComputation inner_condition; { - ComputationBuilder builder(client_, "inner_condition"); + XlaBuilder builder("inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0(7)); @@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the outer loop condition: // repeat while result < 30. - Computation outer_condition; + XlaComputation outer_condition; { - ComputationBuilder builder(client_, "outer_condition"); + XlaBuilder builder("outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0(30)); outer_condition = builder.Build().ConsumeValueOrDie(); @@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. - Computation inner_body; + XlaComputation inner_body; { - ComputationBuilder builder(client_, "inner_body"); + XlaBuilder builder("inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0(1), i); result = builder.Add(builder.ConstantR0(2), result); - auto output = builder.Tuple({i, result}); + builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. - Computation outer_body; + XlaComputation outer_body; { - ComputationBuilder builder(client_, "outer_body"); + XlaBuilder builder("outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0(0), prev}); auto result = builder.While(inner_condition, inner_body, init); - auto output = builder.GetTupleElement(result, 1); + builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(outer_condition, outer_body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition_callee; + XlaComputation condition_callee; { - ComputationBuilder builder(client_, "condition_callee"); + XlaBuilder builder("condition_callee"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.Call(condition_callee, {prev}); builder.GetTupleElement(result, 0); @@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_shape, "state"); builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_shape, "state"); auto indvar = builder.GetTupleElement(state, 0); auto input_0 = builder.GetTupleElement(state, 1); auto input_1 = builder.GetTupleElement(state, 2); auto output = builder.Tanh(builder.Dot(input_0, input_1)); auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + builder.Tuple({indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); auto init = builder.Tuple( {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); @@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) { // Create while condition computation with 'loop_limit'. const int32 loop_limit = 100; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(loop_limit)); @@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) { } // Create while body computation with unit loop increment. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) { auto starts = builder.ConstantR1({0, 0, 0}); // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. - ComputationBuilder builder(client, "while"); + XlaBuilder builder("while"); auto zero = builder.ConstantR0(0.0); auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); auto init = builder.Tuple({builder.ConstantR0(0), input}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 9ad2a1985331b80625dd0687ea052300bc99e440..ff3418a128eed82b730a6602d6e3faba4ad7be32 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -144,7 +145,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions())); + ExecutableBuildOptions().set_hlo_profile(true))); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -294,7 +295,8 @@ XLA_TEST_F(HloProfileTest, auto while_body_profile_start = std::find_if(profile_output_lines.begin(), profile_output_lines.end(), [](tensorflow::StringPiece s) { - return s.starts_with("Execution profile for body"); + return tensorflow::str_util::StartsWith( + s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.end()); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 92b2b1ee778f8b0f8104e7d7ff27a5c11db59768..0af40bc15a41f7c4ef6382b1a94412afe5741a86 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; @@ -25,7 +28,37 @@ GTEST_API_ int main(int argc, char** argv) { return 2; } + // If the --benchmarks flag is passed in then only run the benchmarks, not the + // tests. + for (int i = 1; i < argc; i++) { + tensorflow::StringPiece arg(argv[i]); + if (arg == "--benchmarks" || arg.starts_with("--benchmarks=")) { + const char* pattern = nullptr; + if (arg.starts_with("--benchmarks=")) { + pattern = argv[i] + strlen("--benchmarks="); + } else { + // Handle flag of the form '--benchmarks foo' (no '='). + if (i + 1 >= argc || + tensorflow::StringPiece(argv[i + 1]).starts_with("--")) { + LOG(ERROR) << "--benchmarks flag requires an argument."; + return 2; + } + pattern = argv[i + 1]; + } + // Unfortunately Google's internal benchmark infrastructure has a + // different API than Tensorflow's. +#if defined(PLATFORM_GOOGLE) + base::SetFlag(&FLAGS_benchmarks, pattern); + RunSpecifiedBenchmarks(); +#else + tensorflow::testing::Benchmark::Run(pattern); +#endif + return 0; + } + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 6fa4c48e11d1102367b21bc21d4734466495ef0e..44f874cd2ae8e6f65dc282b8675f195ec9c09415 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -38,7 +38,7 @@ namespace xla { StatusOr> TextLiteralReader::ReadPath( tensorflow::StringPiece path) { - CHECK(!path.ends_with(".gz")) + CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = @@ -115,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece value_string = pieces[1]; tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!coordinates_string.Consume("(")) { + if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); } diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 091fa0c3ec807a66449eca0bfbb141285b8eb532..0bc4045a5490319994b6cf24daf99fe856167507 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -75,6 +75,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -222,17 +223,3 @@ tf_cc_binary( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index 97aacf6b39f83978e732060817cd93ede81ca782..0fa4b98d0a41a1e7c681bb2302da3b752315867b 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -70,17 +70,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 863081d654390440aa6506bab4576b3cc5c1cbd1..adc8b1d620eb65fdca19072831360b71847abf9e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -894,7 +895,7 @@ class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index eda5effbb92db92c9317a956497a00c0ec15c27c..62a353ad09af009e4abf47664a5c5f7bd70a049e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -66,6 +67,7 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; + bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) @@ -122,16 +124,21 @@ StatusOr> ReplayComputation( std::unique_ptr result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { + execution_options.mutable_debug_options()->set_xla_hlo_profile(true); + } + if (opts.print_result) { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); + TF_ASSIGN_OR_RETURN( + result, client->ExecuteAndTransfer(computation, execute_arguments, + &execution_options, &profile)); } else { // If we're not printing the result, execute the computation but don't // bother retrieving the result. This can be a significant speedup. TF_RETURN_IF_ERROR(client ->Execute(computation, execute_arguments, - /*execution_options=*/nullptr, &profile) + &execution_options, &profile) .status()); } LOG(INFO) << "Execution took " @@ -191,6 +198,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag( + "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, + "Pass --xla_hlo_profile the last time we run the computation."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index dc4f7a1cb436183f5acfa360fb092795258b6a75..e43498e381b8e63543e2ddda08ca7c0df91817e4 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -243,8 +243,8 @@ string HumanReadableNumOps(double flops, double nanoseconds, static_cast(nano_flops * 1e9)); tensorflow::StringPiece sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (sp.ends_with("B") || // Ends in 'B', ignoring case - sp.ends_with("b")) { + if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case + tensorflow::str_util::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index ff99d3728d1c3b58fc94d3eb3de78be23407edc9..2da9f9ed6f40fcf5b2512f974519df0b355da10f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -519,6 +519,15 @@ int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } + +// Returns true if `x` fits in 32-bits. +template +bool IsInt32(T x) { + // Following conversion rules: "the value is unchanged if it can be + // represented in the destination type (and bit-field width); otherwise, the + // value is implementation-defined." + return static_cast(x) == x; +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index edf1b07af82b5d43fe67c6efdabdb0a9b4b1edea..5cb18113e5ba9c49809c4410d56ca7bb5a50dae5 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -299,6 +299,11 @@ message ComputationStatsRequest { DebugOptions debug_options = 2; } +message ComputationGraphStatsRequest { + HloModuleProto computation = 1; + DebugOptions debug_options = 2; +} + message ComputationStatsResponse { ComputationStats stats = 1; } @@ -355,6 +360,10 @@ message ExecuteParallelRequest { repeated ExecuteRequest requests = 1; } +message ExecuteGraphParallelRequest { + repeated ExecuteGraphRequest requests = 1; +} + message ExecuteResponse { GlobalDataHandle output = 1; ExecutionProfile profile = 2; diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index c2663c5e83352a1088166dc7581a0346c7b104a4..bf69144ad83c9b5f9a51d4c9e6fbfe61b5f16fb2 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -34,6 +34,7 @@ py_library( "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", @@ -73,11 +74,12 @@ py_library( "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/py2tf", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", @@ -108,10 +110,15 @@ py_library( "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", - ]) + if_not_windows([ - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", # unix dependency, need to fix code + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka", + ], + "//conditions:default": [], + }) + if_not_windows([ + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code - "//tensorflow/contrib/kafka", # has some linking issue on opensssl. ]), ) @@ -121,9 +128,7 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", - "//tensorflow/contrib/kafka:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", @@ -136,7 +141,13 @@ cc_library( "//tensorflow/contrib/text:all_kernels", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", - ]), + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_kernels", + ], + "//conditions:default": [], + }), ) cc_library( @@ -145,12 +156,10 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", - "//tensorflow/contrib/kafka:dataset_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", @@ -161,17 +170,11 @@ cc_library( "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", + ] + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], + "//conditions:default": [], + }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 669d611b01b585d91ab48921b7ba17703dd6bc98..1c5b00f92eace598dea5f035e4954b4b2de8da0e 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=g-import-not-at-top # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,6 +33,7 @@ from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn from tensorflow.contrib import data from tensorflow.contrib import deprecated +from tensorflow.contrib import distribute from tensorflow.contrib import distributions from tensorflow.contrib import estimator from tensorflow.contrib import factorization @@ -85,8 +87,9 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -if os.name != 'nt': +if os.name != "nt": from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 8dff93b4f825277dcf0a64aa3b96bd809d36e1e9..62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -45,16 +45,3 @@ tf_py_test( "//tensorflow/python:state_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 6658f0d9c13f6db17b25354cde2593d57f104f17..8add2aacff1d64f1617cd24167c4c6c6706044da 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -38,16 +38,15 @@ def _flatten_tensors(tensors): shape: the original shape of each element of input tensors Raises: - ValueError: tensors are empty or non-isomorphic. + ValueError: tensors are empty or non-isomorphic or have unknown shape. """ if not tensors: raise ValueError("tensors cannot be empty") shape = tensors[0].shape for tensor in tensors: shape = shape.merge_with(tensor.shape) - if shape.ndims is None: - raise ValueError("At least one of the tensors in 'tensors' must have " - "statically known rank.") + if not shape.is_fully_defined(): + raise ValueError("Tensors must have statically known shape.") if len(shape) != 1: reshaped = [] for t in tensors: diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 47bab0a3670a90644972b2c961954a3036b8ecba..b3f5d92259df8475b205110dd3f0cee1cb5bde6f 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging class AllReduceTest(test_util.TensorFlowTestCase): + def testFlattenTensorsShapesDefined(self): + x = array_ops.placeholder(types_pb2.DT_FLOAT, [None]) + with self.assertRaisesRegexp(ValueError, + "must have statically known shape"): + ar._flatten_tensors([x, x]) + def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 4bff3c27d22c4550747a651a59909bdef80e8285..60306ebdc6cddb04e8807bfd495fa92a56e55ecd 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -38,20 +38,6 @@ cc_library( alwayslink = 1, ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # JAR with Java bindings to TF. android_library( name = "android_tensorflow_inference_java", diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 380a652435ad089f46f3ca80e4fd43097fd96e10..513d519eabbd54f46fde9ec0f004247c02277732 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" namespace tensorflow { namespace { @@ -228,9 +229,8 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { } string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { - string output(name); - StringPiece piece(output); - piece.Consume(prefix_); + StringPiece piece(name); + str_util::ConsumePrefix(&piece, prefix_); return piece.ToString(); } @@ -243,6 +243,11 @@ bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { return AAssetDir_getNextFileName(dir.get()) != NULL; } +Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern, + std::vector* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status AssetManagerFileSystem::NewWritableFile( const string& fname, std::unique_ptr* result) { return errors::Unimplemented("Asset storage is read only."); diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 665304b5eef1f8a3633c8c522259e20d744b1808..a87ff42ae217c429ecf5d2458b88b3431551ad97 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -66,6 +66,9 @@ class AssetManagerFileSystem : public FileSystem { Status DeleteDir(const string& d) override; Status RenameFile(const string& s, const string& t) override; + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + private: string RemoveAssetPrefix(const string& name); diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index a115d1610e2334a6626f29674f3dd195e3a3c648..ecf1a103d2981f409a4598d762fb26100217f779 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -75,7 +75,6 @@ target_link_libraries(tensorflow_inference include_directories( ${PREBUILT_DIR}/proto ${PREBUILT_DIR}/protobuf/include - ${PREBUILT_DIR}/nsync/public ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen ${TENSORFLOW_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/autograph/BUILD similarity index 75% rename from tensorflow/contrib/py2tf/BUILD rename to tensorflow/contrib/autograph/BUILD index d91220f6ddb859ff52d4e5853948cb667981009b..30dd846893c30b9205972bd5216cc1871ab03d76 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -15,16 +15,16 @@ filegroup( ) py_library( - name = "py2tf", + name = "autograph", srcs = [ "__init__.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/impl", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/contrib/py2tf/README.md b/tensorflow/contrib/autograph/README.md similarity index 87% rename from tensorflow/contrib/py2tf/README.md rename to tensorflow/contrib/autograph/README.md index cd50675ad57316b9c749c137e6acd30b91c10073..7e84f237dc9a83098f142a54c48cf5b6ba35aaaa 100644 --- a/tensorflow/contrib/py2tf/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,4 +1,4 @@ -# Py2TF +# Autograph A compiler for generating TensorFlow numeric and control flow ops from Python code. diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/autograph/__init__.py similarity index 59% rename from tensorflow/contrib/py2tf/__init__.py rename to tensorflow/contrib/autograph/__init__.py index 6531183cb59af774299eb767cce111d2ec6f32b4..a39f44b21aa0ddf683b30c18bbe15a43262f7db2 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Py2TF compiles Python code into equivalent TensorFlow code. +"""Autograph compiles Python code into equivalent TensorFlow code. Equivalent here means that they have the same effect when executed. """ @@ -21,18 +21,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.impl.api import convert -from tensorflow.contrib.py2tf.impl.api import converted_call -from tensorflow.contrib.py2tf.impl.api import graph_ready -from tensorflow.contrib.py2tf.impl.api import to_code -from tensorflow.contrib.py2tf.impl.api import to_graph -from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl.api import convert +from tensorflow.contrib.autograph.impl.api import converted_call +from tensorflow.contrib.autograph.impl.api import do_not_convert +from tensorflow.contrib.autograph.impl.api import RunMode +from tensorflow.contrib.autograph.impl.api import to_code +from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'to_graph', 'to_code', 'convert', 'graph_ready', 'converted_call', 'utils', - 'PyFlowParseError' + 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', + 'to_code', 'to_graph', 'AutographParseError' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD similarity index 90% rename from tensorflow/contrib/py2tf/converters/BUILD rename to tensorflow/contrib/autograph/converters/BUILD index 4bb6f76019739fc3b5bf4bf52e302a698693db5a..c5a0dc10959ccb64e090292794bcd0b4fd2dbbd2 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -49,9 +49,9 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -61,6 +61,7 @@ py_test( name = "asserts_test", srcs = ["asserts_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", "//tensorflow/python:client_testlib", @@ -81,7 +82,7 @@ py_test( name = "builtin_functions_test", srcs = ["builtin_functions_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = ["no_windows"], deps = [ ":test_lib", "//tensorflow/python:client_testlib", @@ -90,12 +91,13 @@ py_test( py_test( name = "call_trees_test", + size = "large", srcs = ["call_trees_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/impl", + "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], ) @@ -145,7 +147,7 @@ py_test( srcs = ["name_scopes_test.py"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -201,7 +203,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -212,7 +214,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/__init__.py rename to tensorflow/contrib/autograph/converters/__init__.py index ca10896ee5c6c23d9b20ff23add9945de68e5bf9..e4e8eda42f655e204310eaa9defdd5c90bf06e15 100644 --- a/tensorflow/contrib/py2tf/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Code converters used by Py2TF.""" +"""Code converters used by Autograph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/py2tf/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/asserts.py rename to tensorflow/contrib/autograph/converters/asserts.py index 5b9b8e772bed82df2429fd6cb94dbf7b565e22b3..f011a97ade94f2979486ef6329673a0160dd9bac 100644 --- a/tensorflow/contrib/py2tf/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class AssertsTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/asserts_test.py rename to tensorflow/contrib/autograph/converters/asserts_test.py index 6611f2777a93a7e819c8becfa06a09b27f4e6aaf..cc913febe8d0f411588af69b87ec52ce58f4469c 100644 --- a/tensorflow/contrib/py2tf/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py similarity index 92% rename from tensorflow/contrib/py2tf/converters/break_statements.py rename to tensorflow/contrib/autograph/converters/break_statements.py index bfb709c5e32c6f19dc0fd109df61ece925d701a3..48026bccab5ff3474e9d54e365dad4a589b931fc 100644 --- a/tensorflow/contrib/py2tf/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -20,14 +20,14 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class BreakCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" + """Canonicalizes break statements into additional conditionals.""" def __init__(self, context): super(BreakCanonicalizationTransformer, self).__init__(context) diff --git a/tensorflow/contrib/py2tf/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/break_statements_test.py rename to tensorflow/contrib/autograph/converters/break_statements_test.py index 095fcdff07d44ecc6b9bb7f8d3e2c7c43df72a02..dd4914a022f57b3bb4a19ec132f311f12269fa9e 100644 --- a/tensorflow/contrib/py2tf/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py similarity index 92% rename from tensorflow/contrib/py2tf/converters/builtin_functions.py rename to tensorflow/contrib/autograph/converters/builtin_functions.py index f1129ef153e6be6cbcbbf4bab63c4fe32ec77147..0349ce29ceb097fbebc36a0378b9072750772416 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class BuiltinFunctionTransformer(transformer.Base): @@ -38,13 +38,13 @@ class BuiltinFunctionTransformer(transformer.Base): def _convert_builtin(self, node): template = """ - py2tf_utils.dynamic_builtin(func, args) + autograph_utils.dynamic_builtin(func, args) """ return templates.replace(template, func=node.func, args=node.args)[0].value def _convert_print(self, node): template = """ - py2tf_utils.dynamic_print(args) + autograph_utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py similarity index 96% rename from tensorflow/contrib/py2tf/converters/builtin_functions_test.py rename to tensorflow/contrib/autograph/converters/builtin_functions_test.py index eb60a1d8ae2b56907df8f3ffafe7604883cfc2a9..ac7e756c47c31816ad34a7ea6926917712afa6c3 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -22,8 +22,8 @@ import sys import six -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py similarity index 82% rename from tensorflow/contrib/py2tf/converters/call_trees.py rename to tensorflow/contrib/autograph/converters/call_trees.py index ca8726f9160d106ebd82e01e399e65fb77b02aab..61f6bfd7e733fc3e2e0bea35a955509c39d57bc9 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -22,18 +22,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple import types import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import inspect_utils -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +class FunctionInfo(namedtuple('FunctionInfo', ('dtype',))): + pass + + +# TODO(mdan): Move this to config.py. +KNOWN_NUMPY_FUNCTIONS = { + ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'), +} + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -106,6 +118,12 @@ class CallTreeTransformer(transformer.Base): def _should_compile(self, node, fqn): """Determines whether an entity should be compiled in the context.""" + # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. + module_name = fqn[0] + for mod in self.uncompiled_modules: + if module_name.startswith(mod[0] + '.'): + return False + for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False @@ -179,11 +197,27 @@ class CallTreeTransformer(transformer.Base): return node def _wrap_to_py_func_no_return(self, node): - # TODO(mdan): Properly handle varargs, kwargs, etc. + # TODO(mdan): Properly handle varargs, etc. + template = """ + autograph_utils.wrap_py_func(func, None, (args,), kwargs, True) + """ + return templates.replace( + template, + func=node.func, + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _wrap_to_py_func_single_return(self, node, dtype): + # TODO(mdan): Properly handle varargs, etc. template = """ - py2tf_utils.wrap_py_func(func, None, (original_args,), True) + autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False) """ - return templates.replace(template, func=node.func, original_args=node.args) + return templates.replace_as_expression( + template, + func=node.func, + dtype=parser.parse_expression(dtype), + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) def _insert_dynamic_conversion(self, node): """Inlines a dynamic conversion for a dynamic function.""" @@ -204,10 +238,9 @@ class CallTreeTransformer(transformer.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - py2tf_api.converted_call(func, True, False, {}, original_args) + autograph_api.converted_call(func, True, False, {}, args) """ - call_expr = templates.replace( - template, func=node.func, original_args=node.args) + call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value # TODO(mdan): Improve the template mechanism to better support this. new_call.keywords = node.keywords @@ -248,10 +281,19 @@ class CallTreeTransformer(transformer.Base): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') + if anno.hasanno(node.func, 'fqn'): + target_fqn = anno.getanno(node.func, 'fqn') + else: + target_fqn = None if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) + elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS: + # TODO(mdan): Should we replace these with equivalent TF ops instead? + node = self._wrap_to_py_func_single_return( + node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype) else: - raise NotImplementedError('py_func with return values') + raise NotImplementedError( + 'py_func with return values (unknown function)') else: if self.context.recursive: node = self._insert_dynamic_conversion(node) diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py similarity index 85% rename from tensorflow/contrib/py2tf/converters/call_trees_test.py rename to tensorflow/contrib/autograph/converters/call_trees_test.py index d482a9ef7897388839bbf8f9e4bfc5839d42b2d7..c666dcb73b232ce443898cfe3359f74605af98f2 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import converter_test_base +import numpy as np + +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -105,6 +109,20 @@ class CallTreesTest(converter_test_base.TestCase): sess.run(sess.graph.get_operations()[0]) self.assertEquals('bar', a.foo) + def test_py_func_wrap_known_function(self): + + def test_fn(): + return np.random.binomial(2, 0.5) + + node = self.parse_and_analyze(test_fn, {'np': np}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node, dtypes.int64) as result: + result.np = np + with self.test_session() as sess: + self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) + self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) + def test_uncompiled_modules(self): def test_fn(a): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py similarity index 94% rename from tensorflow/contrib/py2tf/converters/continue_statements.py rename to tensorflow/contrib/autograph/converters/continue_statements.py index 4069a678b118b56b59d2e5491bb80cf52efd8143..4299a8a9d59715d032222c47794bbb4393f34ce6 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class ContinueCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/continue_statements_test.py rename to tensorflow/contrib/autograph/converters/continue_statements_test.py index a598dcd1aed29478b7e3fe27e3c1b20010247dd9..bcbb316d7459aa5a25bb0bd128cd6e359a393288 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/control_flow.py rename to tensorflow/contrib/autograph/converters/control_flow.py index 762c26f0c77e13c077761ceec41cb29db9149a35..49d932026ffa9e79e7ddc640f7d3deaec0f4b8a6 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -82,7 +82,7 @@ class ControlFlowTransformer(transformer.Base): def _create_cond_expr(self, results, test, body_name, orelse_name): if results is not None: template = """ - results = py2tf_utils.run_cond(test, body_name, orelse_name) + results = autograph_utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, @@ -92,7 +92,7 @@ class ControlFlowTransformer(transformer.Base): orelse_name=orelse_name) else: template = """ - py2tf_utils.run_cond(test, body_name, orelse_name) + autograph_utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) @@ -204,7 +204,7 @@ class ControlFlowTransformer(transformer.Base): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state]) + state_ast_tuple = autograph_utils.run_while(test_name, body_name, [state]) """ node = templates.replace( template, diff --git a/tensorflow/contrib/py2tf/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/control_flow_test.py rename to tensorflow/contrib/autograph/converters/control_flow_test.py index b785b284a7fb7a0257551326c88b44a341b295ba..86fed51f27bee07f772633f3928ac5263bf57652 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py similarity index 85% rename from tensorflow/contrib/py2tf/converters/converter_test_base.py rename to tensorflow/contrib/autograph/converters/converter_test_base.py index 8c08c5492a4b10d4abb0ec3b19b39d5b17e41a0a..3ea2cfd668270a69427c24cdf1bbf11d32d66ebe 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/converters/converter_test_base.py @@ -21,15 +21,15 @@ from __future__ import print_function import contextlib import imp -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import pretty_printer -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -75,8 +75,8 @@ class TestCase(test.TestCase): try: result, source = compiler.ast_to_object(node) result.tf = self.make_fake_mod('fake_tf', *symbols) - result.py2tf_utils = utils - result.py2tf_api = self.make_fake_mod('fake_api', converted_call) + result.autograph_utils = utils + result.autograph_api = self.make_fake_mod('fake_api', converted_call) yield result except Exception: # pylint:disable=broad-except if source is None: diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py similarity index 96% rename from tensorflow/contrib/py2tf/converters/decorators.py rename to tensorflow/contrib/autograph/converters/decorators.py index 68bf241ef33292f0581ccb3c44f313f853c92ba7..92445f31746cf94856ea43893f99a2ba60355fb5 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,8 +24,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import pretty_printer class DecoratorsTransformer(gast.NodeTransformer): diff --git a/tensorflow/contrib/py2tf/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/decorators_test.py rename to tensorflow/contrib/autograph/converters/decorators_test.py index c75e5461746f27d14a54b7ac06e7f77d868372c8..e67ab1cd6a15ceb66fe75140419c7abca9653ae4 100644 --- a/tensorflow/contrib/py2tf/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,9 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.pyct import compiler from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/for_loops.py b/tensorflow/contrib/autograph/converters/for_loops.py similarity index 67% rename from tensorflow/contrib/py2tf/converters/for_loops.py rename to tensorflow/contrib/autograph/converters/for_loops.py index 4297c1cf2a3632e097973280cc985fc48da64475..4999c47bdc79ec0ea352472cfd3e97b94ebc7cce 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops.py +++ b/tensorflow/contrib/autograph/converters/for_loops.py @@ -22,10 +22,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class ForLoopCanonicalizationTransformer(transformer.Base): @@ -38,19 +38,19 @@ class ForLoopCanonicalizationTransformer(transformer.Base): self.generic_visit(node) body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) i_var = self.context.namer.new_symbol('i', body_scope.referenced) - n_var = self.context.namer.new_symbol('n', body_scope.referenced) - iterated_var = self.context.namer.new_symbol('iterated', - body_scope.referenced) + smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter', + body_scope.referenced) + cont_var = self.context.namer.new_symbol('cont', body_scope.referenced) # TODO(mdan): Use TensorListFromTensor(loop_iter) here. if anno.hasanno(node, 'extra_cond'): template = """ i = 0 - iterated = loop_iter - n = len(iterated) - while i < n and extra_cond: - target = iterated[i] + smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter) + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) + while cont and extra_cond: body i += 1 + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) """ return templates.replace( template, @@ -58,18 +58,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base): target=node.target, body=node.body, i=i_var, - n=n_var, - iterated=iterated_var, + smart_loop_iter=smart_loop_iter_var, + cont=cont_var, extra_cond=anno.getanno(node, 'extra_cond')) else: template = """ i = 0 - iterated = loop_iter - n = len(iterated) - while i < n: - target = iterated[i] + smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter) + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) + while cont: body i += 1 + cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter) """ repl = templates.replace( template, @@ -77,8 +77,8 @@ class ForLoopCanonicalizationTransformer(transformer.Base): target=node.target, body=node.body, i=i_var, - n=n_var, - iterated=iterated_var) + smart_loop_iter=smart_loop_iter_var, + cont=cont_var) return repl def visit_Continue(self, node): diff --git a/tensorflow/contrib/py2tf/converters/for_loops_test.py b/tensorflow/contrib/autograph/converters/for_loops_test.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/for_loops_test.py rename to tensorflow/contrib/autograph/converters/for_loops_test.py index b6e3e8c8d8d4960977e2b72b56a3fab8329ad2a7..943f52de55a3629fdb18e6188e42269a4cb06275 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops_test.py +++ b/tensorflow/contrib/autograph/converters/for_loops_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import for_loops +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import for_loops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py similarity index 88% rename from tensorflow/contrib/py2tf/converters/ifexp.py rename to tensorflow/contrib/autograph/converters/ifexp.py index 5fd6f348af0df81a6ff35745da603bd431130e20..bb0c0a36a7827e5c73e0fa67f09aa4f54d497a2c 100644 --- a/tensorflow/contrib/py2tf/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class IfExp(transformer.Base): @@ -27,7 +27,7 @@ class IfExp(transformer.Base): def visit_IfExp(self, node): template = """ - py2tf_utils.run_cond(test, lambda: body, lambda: orelse) + autograph_utils.run_cond(test, lambda: (body,), lambda: (orelse,)) """ desugared_ifexp = templates.replace_as_expression( template, test=node.test, body=node.body, orelse=node.orelse) diff --git a/tensorflow/contrib/py2tf/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py similarity index 86% rename from tensorflow/contrib/py2tf/converters/ifexp_test.py rename to tensorflow/contrib/autograph/converters/ifexp_test.py index 9c357ef35b550833bcb79d39f0bdbc6d758d31a5..ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e 100644 --- a/tensorflow/contrib/py2tf/converters/ifexp_test.py +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import ifexp +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import ifexp from tensorflow.python.platform import test @@ -38,7 +38,7 @@ class IfExpTest(converter_test_base.TestCase): return 1 if x else 0 with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils for x in [0, 1]: self.assertEqual(test_fn(x), result.test_fn(x)) @@ -52,7 +52,7 @@ class IfExpTest(converter_test_base.TestCase): return y with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils result.f = f for x in [-2, 2]: self.assertEqual(test_fn(x), result.test_fn(x)) @@ -63,7 +63,7 @@ class IfExpTest(converter_test_base.TestCase): return x * x if x > 0 else x with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils for x in [-2, 2]: self.assertEqual(test_fn(x), result.test_fn(x)) @@ -73,7 +73,7 @@ class IfExpTest(converter_test_base.TestCase): return x * x if x > 0 else x if x else 1 with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils for x in [-2, 0, 2]: self.assertEqual(test_fn(x), result.test_fn(x)) @@ -85,7 +85,7 @@ class IfExpTest(converter_test_base.TestCase): return -x with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils for x in [-2, 2, 5]: self.assertEqual(test_fn(x), result.test_fn(x)) @@ -97,7 +97,7 @@ class IfExpTest(converter_test_base.TestCase): return x with self.compiled_fn(test_fn) as result: - result.py2tf_util = utils + result.autograph_util = utils for x in [-2, 2, 5]: self.assertEqual(test_fn(x), result.test_fn(x)) diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension.py rename to tensorflow/contrib/autograph/converters/list_comprehension.py index e8744831100e4852919b5cd1253b74acea4d790d..d7f292015164e047d054c5d1fb0b391e960bb73d 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,9 +31,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class ListCompCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehension_test.py index 025fac11e41e6771fbb9b80ff3da70dc3ceec73e..4758671f5ec83c26cfa54be0ef68f5f564094f6c 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import list_comprehension from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/lists.py rename to tensorflow/contrib/autograph/converters/lists.py index 06e1dad8f4d652da78ed39309f5b40598e368ea6..234a0a7487d5fc9e068acf4a19af3bac84f4737e 100644 --- a/tensorflow/contrib/py2tf/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -32,9 +32,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.framework import dtypes @@ -61,17 +61,20 @@ class ListTransformer(transformer.Base): return templates.replace_as_expression(template, dtype_name=dtype_name) def _pre_populated_list(self, node): - raise NotImplementedError() + raise NotImplementedError('pre-populated lists') def visit_Expr(self, node): node = self.generic_visit(node) if isinstance(node.value, gast.Call): call_node = node.value + + if not anno.hasanno(call_node.func, anno.Basic.QN): + return node qn = anno.getanno(call_node.func, anno.Basic.QN) if qn.qn[-1] == 'append' and (len(call_node.args) == 1): template = """ - target = py2tf_utils.dynamic_list_append(target, element) + target = autograph_utils.dynamic_list_append(target, element) """ node = templates.replace( template, diff --git a/tensorflow/contrib/py2tf/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/lists_test.py rename to tensorflow/contrib/autograph/converters/lists_test.py index 671a1cc7b1225061a00731596c536c4403e0bdff..749ba14347314f975c5a6e1111133336e2f5c5e6 100644 --- a/tensorflow/contrib/py2tf/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import lists +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import lists from tensorflow.python.framework import dtypes from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py similarity index 71% rename from tensorflow/contrib/py2tf/converters/logical_expressions.py rename to tensorflow/contrib/autograph/converters/logical_expressions.py index 10192e6a036c4a44aa1e6f1b4a390579bd703373..3a795a315a3c2aa08ac1577a204102755b6e849c 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -23,9 +23,10 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer # TODO(mdan): Properly extrack boolean ops according to lazy eval rules. @@ -44,19 +45,20 @@ class LogicalExpressionTransformer(transformer.Base): def __init__(self, context): super(LogicalExpressionTransformer, self).__init__(context) # TODO(mdan): Look into replacing with bitwise operators instead. + # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { - gast.And: 'logical_and', - gast.Eq: 'equal', - gast.Gt: 'greater', - gast.GtE: 'greater_equal', - gast.Lt: 'less', - gast.LtE: 'less_equal', - gast.Not: 'logical_not', - gast.NotEq: 'not_equal', - gast.Or: 'logical_or', - gast.USub: 'negative', - gast.Is: 'py2tf_utils.dynamic_is', - gast.IsNot: 'py2tf_utils.dynamic_is_not' + gast.And: 'tf.logical_and', + gast.Eq: 'tf.equal', + gast.Gt: 'tf.greater', + gast.GtE: 'tf.greater_equal', + gast.Lt: 'tf.less', + gast.LtE: 'tf.less_equal', + gast.Not: 'tf.logical_not', + gast.NotEq: 'tf.not_equal', + gast.Or: 'tf.logical_or', + gast.USub: 'tf.negative', + gast.Is: 'autograph_utils.dynamic_is', + gast.IsNot: 'autograph_utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): @@ -70,27 +72,19 @@ class LogicalExpressionTransformer(transformer.Base): '"a.x or b"; for a workaround, assign the expression to a local ' 'variable and use that instead, for example "tmp = a.x", "tmp or b"') - def _matching_tf_op(self, operator): + def _matching_func(self, operator): op_type = type(operator) mapped_op = self.op_mapping.get(op_type) if not mapped_op: raise NotImplementedError('operator %s is not yet supported' % op_type) return mapped_op - def _inline_tf_op(self, op_name, args): - if 'py2tf_utils' in op_name: - # TODO(alexbw): explicitly spelling out the attribute function name - # until fix for issue highlighted in cl/188931581 lands. - template = """ - py2tf_utils.op_name(args) + def _as_function(self, func_name, args): + template = """ + func_name(args) """ - op_name = op_name.replace('py2tf_utils.', '') - else: - template = """ - tf.op_name(args) - """ replacement = templates.replace_as_expression( - template, op_name=op_name, args=args) + template, func_name=parser.parse_expression(func_name), args=args) anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) return replacement @@ -104,14 +98,14 @@ class LogicalExpressionTransformer(transformer.Base): # a < b < c -> a < b and b < c while ops_and_comps: op, right = ops_and_comps.pop(0) - binary_comparison = self._inline_tf_op(self._matching_tf_op(op), - (left, right)) + binary_comparison = self._as_function( + self._matching_func(op), (left, right)) if isinstance(left, gast.Name) and isinstance(right, gast.Name): anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) if op_tree: self._expect_simple_symbol(right) - op_tree = self._inline_tf_op('logical_and', - (binary_comparison, op_tree)) + op_tree = self._as_function('tf.logical_and', + (binary_comparison, op_tree)) else: op_tree = binary_comparison left = right @@ -120,7 +114,7 @@ class LogicalExpressionTransformer(transformer.Base): def visit_UnaryOp(self, node): node = self.generic_visit(node) - return self._inline_tf_op(self._matching_tf_op(node.op), node.operand) + return self._as_function(self._matching_func(node.op), node.operand) def visit_BoolOp(self, node): node = self.generic_visit(node) @@ -130,7 +124,7 @@ class LogicalExpressionTransformer(transformer.Base): while node_values: left = node_values.pop() self._expect_simple_symbol(left) - right = self._inline_tf_op(self._matching_tf_op(node.op), (left, right)) + right = self._as_function(self._matching_func(node.op), (left, right)) return right diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py similarity index 92% rename from tensorflow/contrib/py2tf/converters/logical_expressions_test.py rename to tensorflow/contrib/autograph/converters/logical_expressions_test.py index eb28c309a429f2267cc1ae1f6f65a8cde0ad91b8..2814060c4d831e4dddacb3dcbcbe1db42160db20 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.python.ops import math_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/name_scopes.py rename to tensorflow/contrib/autograph/converters/name_scopes.py index c702823fcf047fcad3254318bd323d2b8fddd700..2a3f474360e94635470bf9581222e4c79f46b7a1 100644 --- a/tensorflow/contrib/py2tf/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -21,8 +21,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class FunctionNameScopeTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/name_scopes_test.py rename to tensorflow/contrib/autograph/converters/name_scopes_test.py index a8ca341602ee5f06dbb812643a58794339d98afe..61e5db2af826d0c2238f1af0f3240411596f7429 100644 --- a/tensorflow/contrib/py2tf/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import name_scopes +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py similarity index 91% rename from tensorflow/contrib/py2tf/converters/side_effect_guards.py rename to tensorflow/contrib/autograph/converters/side_effect_guards.py index 30976b3ec6db5a6607023ac804d9d54cfb296190..1c1293d2c411b51b563ac3965284a48725ed3278 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,12 +36,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base): [alias_map.get(s, s).ast() for s in guarded_args], None) template = """ - with py2tf_utils.control_dependency_on_returns(call): - aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args) + with autograph_utils.control_dependency_on_returns(call): + aliased_guarded_args = autograph_utils.alias_tensors(guarded_args) """ control_deps_guard = templates.replace( template, @@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base): alias_map = {} template = """ - with py2tf_utils.control_dependency_on_returns(call): + with autograph_utils.control_dependency_on_returns(call): pass """ control_deps_guard = templates.replace(template, call=node.value)[-1] diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py similarity index 97% rename from tensorflow/contrib/py2tf/converters/side_effect_guards_test.py rename to tensorflow/contrib/autograph/converters/side_effect_guards_test.py index 463db2e770213ba9636d2537b095a77dece5d8f6..ce0ce33243a1352107eb8121050ee76474869809 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/py2tf/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py similarity index 96% rename from tensorflow/contrib/py2tf/converters/single_return.py rename to tensorflow/contrib/autograph/converters/single_return.py index 1194b98f5ebeffa79a41fc3b32aa79ffd8cc407b..bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7 100644 --- a/tensorflow/contrib/py2tf/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Move this logic into transformer_base. @@ -232,7 +232,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): def visit_Return(self, node): if self.cant_return: raise ValueError( - 'Pyflow currently does not support `return` statements in loops. ' + '`return` statements are not supported in loops. ' 'Try assigning to a variable in the while loop, and returning ' 'outside of the loop') diff --git a/tensorflow/contrib/py2tf/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py similarity index 97% rename from tensorflow/contrib/py2tf/converters/single_return_test.py rename to tensorflow/contrib/autograph/converters/single_return_test.py index 2ea7a9d6d3e25c8dafd8f211994c8fe99bd0e781..d483005a09537ea8227814f65aa7e6402c853f60 100644 --- a/tensorflow/contrib/py2tf/converters/single_return_test.py +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import single_return +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import single_return from tensorflow.python.framework.ops import name_scope from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d62390494b78c415212ba91ac914cdfee324f971 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -0,0 +1,1919 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Dev Summit 2018 - Autograph", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python2", + "display_name": "Python 2" + } + }, + "cells": [ + { + "metadata": { + "id": "g7nGs4mzVUHP", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Experimental: TF Autograph\n", + "**TensorFlow Dev Summit, 2018.**\n", + "\n", + "This interactive notebook demonstrates **autograph**, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.\n", + "\n", + "**Note: this is pre-alpha software!** The notebook works best with Python 2, for now.\n", + "\n", + "> ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", + "\n", + "### Table of Contents\n", + "1. _Write Eager code that is fast and scalable._\n", + "2. _Case study: complex control flow._\n", + "3. _Case study: training MNIST with Keras._\n", + "4. _Case study: building an RNN._" + ] + }, + { + "metadata": { + "id": "uFcgBENZqkB2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Install TensorFlow; note that Colab notebooks run remotely, on virtual\n", + "# instances provided by Google.\n", + "!pip install -U -q tf-nightly" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Pa2qpEmoVOGe", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import os\n", + "import time\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import six\n", + "\n", + "from google.colab import widgets" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "ZVKfj5ttVkqz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 1. Write Eager code that is fast and scalable\n", + "\n", + "TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.\n", + "\n", + "Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." + ] + }, + { + "metadata": { + "id": "snaZRFdWd9ym", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "For example, autograph can convert a function like this:" + ] + }, + { + "metadata": { + "id": "9__n8cSIeDnD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def g(x):\n", + " if x > 0:\n", + " x = x * x\n", + " else:\n", + " x = 0\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "gq0eQcuReHET", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "... into a TF graph-building function:" + ] + }, + { + "metadata": { + "id": "sELSn599ePUF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 413 + }, + "outputId": "bb0c7216-1ca3-4da1-d1fb-589902cdcd1a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737505, + "user_tz": 240, + "elapsed": 243, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "print(autograph.to_code(g))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "from tensorflow.contrib.autograph.impl import api as autograph_api\n", + "from tensorflow.contrib.autograph import utils as autograph_utils\n", + "\n", + "def tf__g(x):\n", + " with tf.name_scope('g'):\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_1, = x,\n", + " x_1 = 0\n", + " return x_1,\n", + " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "j74n-8hEe6dk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can then use the converted function as you would any regular TF op -- you can pass `Tensor` arguments and it will return `Tensor`s:" + ] + }, + { + "metadata": { + "id": "AkVaY0-dfEbH", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "4ffe3757-c44d-424c-c2a8-7ddc973bfcce", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737841, + "user_tz": 240, + "elapsed": 257, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "tf_g = autograph.to_graph(g)\n", + "\n", + "with tf.Graph().as_default(): \n", + "\n", + " g_ops = tf_g(tf.constant(9))\n", + "\n", + " with tf.Session() as sess:\n", + " tf_g_result = sess.run(g_ops)\n", + "\n", + " print('g(9) = %s' % g(9))\n", + " print('tf_g(9) = %s' % tf_g_result)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "g(9) = 81\n", + "tf_g(9) = 81\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "trrHQBM1VnD0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 2. Case study: complex control flow\n", + "\n", + "Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph.\n", + "Autograph will automatically convert most Python control flow statements into their correct graph equivalent.\n", + " " + ] + }, + { + "metadata": { + "id": "u0YG3DPgZxoW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:" + ] + }, + { + "metadata": { + "id": "xJYDzOcrZ8pI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "6c244ee4-b141-4ad6-eefa-cfffa71f33c6", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738402, + "user_tz": 240, + "elapsed": 483, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def sum_even(numbers):\n", + " s = 0\n", + " for n in numbers:\n", + " if n % 2 > 0:\n", + " continue\n", + " s += n\n", + " return s\n", + "\n", + "\n", + "tf_sum_even = autograph.to_graph(sum_even)\n", + "\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " result = sess.run(tf_sum_even(tf.constant([10, 12, 15, 20])))\n", + "\n", + " print('Sum of even numbers: %s' % result)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(sum_even))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Sum of even numbers: 42\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "_YXo4KOcbKrn", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Try replacing the `continue` in the above code with `break` -- Autograph supports that as well!" + ] + }, + { + "metadata": { + "id": "xHmC0rBIavW_", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The Python code above is much more readable than the matching graph code. Autograph takes care of tediously converting every piece of Python code into the matching TensorFlow graph version for you, so that you can quickly write maintainable code, but still benefit from the optimizations and deployment benefits of graphs." + ] + }, + { + "metadata": { + "id": "UEHWGpBXbS7g", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " + ] + }, + { + "metadata": { + "id": "qUU57xlEbauI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "add3db4a-2077-4dd5-f7a7-a5b5a4529c26", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738697, + "user_tz": 240, + "elapsed": 253, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(x):\n", + " assert x != 0, 'Do not pass zero!'\n", + " return x * x\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " try:\n", + " print(sess.run(tf_f(tf.constant(0))))\n", + " except tf.errors.InvalidArgumentError as e:\n", + " print('Got error message: %s' % e.message)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Got error message: assertion failed: [Do not pass zero!]\n", + "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "w5hBZaVJbck4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can also use `print` functions in-graph:" + ] + }, + { + "metadata": { + "id": "6NdzRKLEboRv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "fb82dfc3-790f-4127-87f6-361805be9e9b", + "executionInfo": { + "status": "ok", + "timestamp": 1522345739013, + "user_tz": 240, + "elapsed": 247, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_sign(n):\n", + " if n >= 0:\n", + " print(n, 'is positive!')\n", + " else:\n", + " print(n, 'is negative!')\n", + " return n\n", + "\n", + "\n", + "tf_print_sign = autograph.to_graph(print_sign)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " sess.run(tf_print_sign(tf.constant(1)))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_sign))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1 is positive!\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "9u_Z3i3AivLA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We can convert lists to TensorArray, so appending to lists also works, with a few modifications:" + ] + }, + { + "metadata": { + "id": "MjhCQJVuiTNR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "dc320b87-595b-4392-d29c-994486fd8a0a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345744470, + "user_tz": 240, + "elapsed": 5391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(n):\n", + " numbers = []\n", + " # We ask you to tell us about the element dtype.\n", + " autograph.utils.set_element_type(numbers, tf.int32)\n", + " for i in range(n):\n", + " numbers.append(i)\n", + " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + "\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(tf_f(tf.constant(5))))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[0 1 2 3 4]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "UdG8ZFrkTAF2", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "And all of these functionalities, and more, can be composed into more complicated code:\n" + ] + }, + { + "metadata": { + "id": "DVs6wt8NKaGQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "cellView": "code", + "outputId": "0a4b8d08-8f65-4bbc-85ba-dc4c60563519", + "executionInfo": { + "status": "ok", + "timestamp": 1522345745186, + "user_tz": 240, + "elapsed": 658, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_primes(n):\n", + " \"\"\"Returns all the prime numbers less than n.\"\"\"\n", + " assert n > 0\n", + " \n", + " primes = []\n", + " autograph.utils.set_element_type(primes, tf.int32)\n", + " for i in range(2, n):\n", + " is_prime = True\n", + " for k in range(2, i):\n", + " if i % k == 0:\n", + " is_prime = False\n", + " break\n", + " if not is_prime:\n", + " continue\n", + " primes.append(i)\n", + " all_primes = primes.stack()\n", + "\n", + " print('The prime numbers less than', n, 'are:')\n", + " print(all_primes)\n", + " return tf.no_op()\n", + "\n", + " \n", + "tf_print_primes = autograph.to_graph(print_primes)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " n = tf.constant(50)\n", + " sess.run(tf_print_primes(n))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_primes))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "The prime numbers less than 50 are:\n", + "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "JQ8kQT99VqDk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 3. Case study: training MNIST with Keras\n", + "\n", + "As we've seen, writing control flow in Autograph is easy. So running a training loop in graph should be easy as well!\n", + "\n", + "Here, we show an example of such a training loop for a simple Keras model that trains on MNIST." + ] + }, + { + "metadata": { + "id": "0CrtGWgwuLJr", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import gzip\n", + "import shutil\n", + "\n", + "from six.moves import urllib\n", + "\n", + "\n", + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zu1U9Nqir6L", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "First, we'll define a small three-layer neural network using the Keras API" + ] + }, + { + "metadata": { + "id": "x_MU13boiok2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')])\n", + " model.build()\n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wuqg3H8mi0Xj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's connect the model definition (here abbreviated as `m`) to a loss function, so that we can train our model." + ] + }, + { + "metadata": { + "id": "W51sfbONiz_5", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "035tNWQki9tr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now the final piece of the problem specification (before loading data, and clicking everything together) is backpropagating the loss through the model, and optimizing the weights using the gradient." + ] + }, + { + "metadata": { + "id": "CsAD0ajbi9iZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PcVRIacKjSwb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "These are some utility functions to download data and generate batches for training" + ] + }, + { + "metadata": { + "id": "RVw57HdTjPzi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n", + "\n", + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zEJH5XNjgFz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.\n", + "\n", + "In the loop inside this function, we'll grab a batch of data, apply an update to the weights of our model to improve its performance, and then record its current training loss and accuracy. Every so often, we'll log some information about training as well." + ] + }, + { + "metadata": { + "id": "UUI0566FjZPx", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + " train_losses = []\n", + " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n", + " test_losses = []\n", + " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n", + " train_accuracies = []\n", + " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n", + " tf.float32)\n", + " test_accuracies = []\n", + " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n", + " tf.float32)\n", + " i = tf.constant(0)\n", + " while i < hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + " if i % (hp.max_steps // 10) == 0:\n", + " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n", + " step_test_loss, 'train accuracy:', step_train_accuracy,\n", + " 'test accuracy:', step_test_accuracy)\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + " i += 1\n", + " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", + " test_accuracies.stack())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "cYiUQ1ppkHzk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Everything is ready to go, let's train the model and plot its performance!" + ] + }, + { + "metadata": { + "id": "K1m8TwOKjdNd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 988 + }, + "outputId": "f9d3eef3-5bea-45c1-ddf9-4edee73e4436", + "executionInfo": { + "status": "ok", + "timestamp": 1522345800262, + "user_tz": 240, + "elapsed": 52391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=500,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n", + " test_accuracies])\n", + " plt.title('MNIST train/test losses')\n", + " plt.plot(train_losses, label='train loss')\n", + " plt.plot(test_losses, label='test loss')\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Loss')\n", + " plt.show()\n", + " plt.title('MNIST train/test accuracies')\n", + " plt.plot(train_accuracies, label='train accuracy')\n", + " plt.plot(test_accuracies, label='test accuracy')\n", + " plt.legend(loc='lower right')\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Accuracy')\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/train-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/train-labels-idx1-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/t10k-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/t10k-labels-idx1-ubyte.gz\n", + "Step 0 train loss: 2.244329 test loss: 2.2499208 train accuracy: 0.12 test accuracy: 0.161\n", + "Step 50 train loss: 0.64771986 test loss: 0.56013924 train accuracy: 0.82 test accuracy: 0.836\n", + "Step 100 train loss: 0.49011207 test loss: 0.42143965 train accuracy: 0.84 test accuracy: 0.879\n", + "Step 150 train loss: 0.3768609 test loss: 0.39319593 train accuracy: 0.88 test accuracy: 0.883\n", + "Step 200 train loss: 0.36007702 test loss: 0.37089333 train accuracy: 0.9 test accuracy: 0.881\n", + "Step 250 train loss: 0.182115 test loss: 0.28543878 train accuracy: 0.94 test accuracy: 0.915\n", + "Step 300 train loss: 0.2119576 test loss: 0.22305593 train accuracy: 0.92 test accuracy: 0.93\n", + "Step 350 train loss: 0.12932214 test loss: 0.29057172 train accuracy: 0.96 test accuracy: 0.906\n", + "Step 400 train loss: 0.22937602 test loss: 0.2200287 train accuracy: 0.92 test accuracy: 0.925\n", + "Step 450 train loss: 0.23444137 test loss: 0.19857481 train accuracy: 0.94 test accuracy: 0.94\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3XmAFNW9Pvynlt5mYdhmQMHggnGN\nS9zCD0ElKug1edUY9ZoQTYze3GuiRk1uYjRqRHNj4n5NrhKjiUYlbihGQFRUFDSoKIvgICAO6+xL\n711V5/2jlq7qZaZnpnumZ3g+/zjTXV1dXSP91PecU+dIQggBIiIiGjLkwT4AIiIi6h2GNxER0RDD\n8CYiIhpiGN5ERERDDMObiIhoiGF4ExERDTEMb6JeOOigg3DllVdmPf6rX/0KBx10kGe766+/3rPN\ne++9h9mzZwMAtm3bhkMPPdR57osvvsCPfvQjzJw5EzNnzsTZZ5+NV199FQBw0003YdasWZg1axYO\nO+wwnHLKKc7v4XDY8x7JZBLz58/v9edavXo1Lr300oK2XbBgAebMmdPn97J19/rZs2fjhRde6PO+\niYY7hjdRL3366aee0Ewmk1izZk3WditXrsQnn3xS0D6vu+46TJs2DYsXL8bixYtxyy234LrrrsPO\nnTtxyy23YNGiRVi0aBHGjRuH3//+987vVVVVnv188sknfQrUI444Ag8//HBB2y5fvhxTpkzp83vZ\n+vt6oj0Zw5uol0444QQsWbLE+f3tt9/GV77ylaztrrnmGtx+++0F7bO+vh5HHnmk8/uRRx6JxYsX\nY/z48QUfV3NzM3784x/jo48+wkUXXQTAbAF48MEHMXPmTOi6jlWrVuHcc8/FrFmzcOaZZ2L58uUA\nzFaB0047DQBw//334ze/+Q2uuOIKfP3rX8d5552HxsZG533ee+89HHzwwVnv9cEHH+Bb3/oWTjvt\nNJx//vloaGgAAOzevRsXX3wxzjzzTJx66qm4++67cx5rPu+99x7OOecczJo1C9/+9redC6Vc++3u\ncSEE/vd//xczZ87EKaecgjlz5kDXdQDAwoULcdZZZ+GMM87AN77xDbz33nsFn3eiwcDwJuqlM844\nAy+99JLz+z//+U/MmjUr53ZCCCxatKjHfU6fPh1XXnkl/va3v2HTpk0AgHHjxkGSpIKPa+zYsbjm\nmmtw1FFH4YknnnAeF0Jg8eLFUBQFv/71r3HppZdi0aJFuPzyy3HTTTfl3NeiRYtw/fXX49VXX8WY\nMWPw7LPPAgA2bdqE2tpaTJgwwfNe4XAY//mf/4lrrrkGS5Yswfe+9z1cddVVAIBHH30Uxx13HF5+\n+WUsWLAADQ0NMAwj57FmikQiuOqqq3DDDTdg0aJF+OEPf4jrrrsOhmHk3G9jY2Pex1944QUsWrQI\nzzzzDJYsWYKGhgY8+eSTAIBbbrkFDz74IBYuXIibbroJr7/+esHnnWgwMLyJeun444/Hxo0b0dLS\nglgshlWrVmHKlCk5t73++uvxhz/8AYlEott9/v73v8d3vvMdLFiwAGeddRZmzJjhBEt/nXzyyc7P\n8+fPxxlnnAEAOOaYY5zqONOxxx6LCRMmQJIkHHLIIdi5cycAYMWKFTk/6wcffIBx48Zh6tSpAICz\nzjoLX3zxBXbs2IExY8bg7bffxvvvvw+/34+77roLdXV1BR376tWrMX78eBxzzDEAgJkzZ6KtrQ3b\nt2/Pu998jy9duhTf+ta3UF1dDVVV8e1vfxuvvPIKAGDMmDF46qmnsH37dhx77LH45S9/WdjJJRok\n6mAfANFQoygKTj/9dCxcuBCjR4/GiSeeCFXN/U/psMMOw3HHHYdHHnkERx99dN59BgIBXHrppbj0\n0kvR2dmJRYsW4fbbb8fEiRMxbdq0fh3vyJEjnZ8XLFiAv/3tb4hEIjAMA/mWNqiurnZ+VhTFaV5+\n5513cMkll2Rt39nZiYaGBk8LhN/vR2trKy655BIYhoFbbrkFjY2N+M53voOf/OQnBR17a2srRowY\nkXVsLS0tefeb7/Guri48/PDDmDdvHgBA13WMHj0aAPCnP/0Jf/rTn3Duuedir732wvXXX4/jjz++\noGMkGgwMb6I+OPPMM3H33Xdj1KhRPfbZ/vSnP8W5556LiRMn5ny+tbUV69evd6rWESNG4Pzzz8ey\nZctQX1/f7/C27d69GzfccAOefvppHHLIIfj8888xc+bMgl+vaRrWrFmT8yKkrq4O+++/P5577rmc\nr7388stx+eWXY8uWLbjsssucSronY8aMQXt7u/O7EAIdHR0YM2YMVFXNud+pU6fmfLyurg4zZszA\nd7/73az3+dKXvoTf/va3MAwD8+fPx7XXXotly5YVeGaIBh6bzYn64Oijj0ZjYyM2btzYY4VWV1eH\n73znO7j//vtzPh+Px3HllVd6wmLr1q34+OOPceyxx/bquFRVRTgczllRt7a2oqKiAvvvvz80TXMq\n0EgkUtC+V69ejYMOOgh+vz/rvY488kg0NTXh448/BgA0NDTgZz/7GYQQ+PWvf4133nkHgBmSY8eO\nhSRJ3R6r7YgjjkBzczNWrVoFwBxfMH78eEycODHvfvM9/vWvfx0vvPACYrEYAOCpp57C888/j9bW\nVnz/+99HOByGLMs48sgjezXWgGgwsPIm6gNJknDaaachFotBlnu+Bv7BD36Ap59+Oudze++9N/70\npz/hvvvuw5w5cyCEQFVVFX75y196RqAX4phjjsEf/vAHTJs2DW+++abnuYMPPhjTp0/HzJkzMWbM\nGPziF7/Ahx9+iNmzZ+O///u/e9y3fYtYvve67777cOuttyISicDn8+Gqq66CJEm48MIL8etf/xq3\n3norhBCYMWMGpkyZgh07dnheryhK1ntWVFTgnnvuwa233opoNIrRo0fjrrvu6na/I0eOzPk4AGzc\nuBHnnHMOADPYb7vtNowePRrTpk3Dt771LSiKAp/Ph9tuu61X551ooElcz5uIiGhoYbM5ERHREMPw\nJiIiGmIY3kREREMMw5uIiGiIYXgTERENMUPmVrGmpq6i7m/UqAq0tUWLus89Ec9j//Ec9h/PYXHw\nPPZfsc9hbW11zsf32MpbVbPvKaXe43nsP57D/uM5LA6ex/4bqHO4x4Y3ERHRUMXwJiIiGmIY3kRE\nREMMw5uIiGiIYXgTERENMQxvIiKiIYbhTURENMQwvImIaNh6443XCt723nvvxI4d23vc7sMP38cN\nN/y8P4fVbwxvIiIalnbu3IFXX11c8PZXXXUt9t57QgmPqHiGzPSoREREvXHXXb/D+vXr8Mgjc2EY\nBnbs2I6dO3fgnnv+iN/+9jdoampELBbDD35wOaZOnYYf//hyXHPNz7F06WuIRML44out2L59G668\n8lpMmTI153u89toSzJv3dyiKgoMOOgS33XYL6us34M47fwefzwe/349bbvktdu7cnvVYdXXuqU8L\nsceGd0c4gfc3NOLYg+sG+1CIiIa9f7z+GVZuaCzqPo87uA7nz5ic9/l///fZeO65f+D7378MDz/8\nIDQthT/+8c9oa2vF8cd/DWeccRa2b9+GG2/8BaZOneZ5bWPjbvzhD/fh3XeX44UXns0Z3tFoFA89\n9AAeeeQJVFRU4Oc//yneffddvPzyyzjnnPMwa9a/4YMPVqK1tQUvv7wg6zGGdx9ceecbaO2M46ZL\njsOk8X0/gURENDQccshhAIDq6hFYv34dXnzxOUiSjM7OjqxtjzjiKABAXV0dwuFwzv01NHyBiRO/\nhIqKCgDA0Ucfg/Xr1+PEE0/CH/7wP2ho+AJf//ppmDRp35yP9cceGd5b23YiPOFNSMnD0dwRZ3gT\nEZXY+TMmd1slDwSfzwcAWLJkETo7O/HAA39GZ2cnfvjD2VnbKkp6gREhRM79SZL3OU1LQZJCOPbY\n4/HnP/8Ny5cvw5w5N+PHP74652Nf/eqxff4se2R4f7ztCyjVbTBG70RLZ3ywD4eIiEpAlmXoup71\neHt7O/baa2/Isow333wdqVSqT/vfZ59J2LbtC0SjEVRUVGLVqg9x1VU/xrPPzsOUKSfi9NPPgBAC\n9fUbsGXLpqzHGN69dPykA7G4CZArO9DSwfAmIhqOJk3aD59+ugH33XcnKiurnMdPPnkGfvGLa/DJ\nJ2vxb//2TdTV1eGRR+b2ev+hUAhXXHEVrr32J5AkGUcccRSOPfZY7NzZghtv/AWqqqrg8/lw/fU3\nob7+06zH+kMS+doDykxTU1dR93fjit+ipTOCQyLn4yfnHlHUfe9Jamuri/632dPwHPYfz2Fx8Dz2\nX7HPYW1t7m7dPfY+7y+P2Q+SL4mmcOtgHwoREVGv7LHhPbFmPACgLdk2yEdCRETUO3tseI8JjQIA\nxBFGPKkN8tEQEREVbs8N74rRAADJH+egNSIiGlL22PAeW2FW3pI/xtvFiIhoSNljw3uME96svImI\naGjZY8M75AvCLwcg+eNoZuVNRDQs9WZJUNtHH32ItjbvnUjlsAyo2x4b3gAwMlDDypuIaJjq7ZKg\ntn/+88Ws8C43e+QMa7a6ijFojDWiqSt7UnoiIhra3EuCXnDBRbj99lvQ1dUFXddx9dU/w+TJB+Lx\nxx/Fm28uhSzLmDp1Gg455FAsW/YGtmzZjDlz7sD48eOz9pu5DOjVV1/nLANaWRkCIJdkGVC3PTy8\nxwItQKfWPtiHQkQ0rD332UtY1bimqPs8uu4rOHfyWXmfdy8J+uijf8YJJ/w/fOMbZ2PLls24994/\n4J57/oinnnoc8+cvgqIomD//WRx33NcwefKXcc01P88Z3LmWAf3ww/fx1ltLcc4552H27AuxaNHr\nJVkG1G2PDu/a0FgAQAysvImIhrM1a1ajvb0Nixe/DABIJMzu0pNP/jquvvq/cNpps3D66bN63E+u\nZUDr6zc4S362tOzClCknlWQZULc9OrzrKszwTildMISALEmDfERERMPTuZPP6rZKLjWfT8VPf/oz\nHH64dy2L6677JbZu/Ryvv74EP/nJf+Chh/7a7X5yLQMaCAScJT/XrFlZsmVA3fboAWt25Y1gFNE4\nZ1kjIhpO3EuCHnro4XjrrTcAAFu2bMZTTz2OcDiMRx6Zi0mT9sX3v38ZqqtrEI1G8i4lCniXAQWA\nVas+xEEHHYpnn52Hzs4OfPOb38QFF1yE+voNzmOnn36G81ix7NGV96hgDSQhQw5EEYmnUBXyDfYh\nERFRkbiXBP3hD3+E2267Gf/1Xz+EYRi4+urrUFVVhfb2Nlx22fcQClXg8MOPwIgRNTjqqK/ihhv+\nG7/97Z3Yf/8DPPvMtQzokUcehVgsihtv/AVGjaoBIJdkGVC3PXZJUHvZtp++fgvicQM/O+pa7L/3\niKK+x56ASwj2H89h//EcFgfPY/9xSdABEpCCkNQUIvHUYB8KERFRQfb48A4qIUiqhs4oJ2ohIqKh\nYY8P7wo1BABoj4YH+UiIiIgKs8eHd6XPvFevIxEZ5CMhIiIqzB4f3iMClQCAjjjDm4iIhoY9PrxH\nVZgj+Xa1tw3ykRARERVmjw/v0RXm7WE7OjrQHk4M8tEQERH1bI8P70qfOWBNUpNYvallkI+GiIio\nZwxvn9nnDSWFpvbY4B4MERFRAUo6Peodd9yBDz74AJqm4T/+4z9w+umnO88tX74cd911FxRFwfTp\n03HFFVeU8lDysm8Vk9QUWjvZbE5EROWvZOH97rvvYuPGjZg3bx7a2tpwzjnneMJ7zpw5ePjhhzFu\n3Dh897vfxcyZMzF58uRSHU5eITVo/qBoaOviRC1ERFT+Shbexx13HI44wlx6bcSIEYjFYtB1HYqi\noKGhATU1Ndhrr70AACeddBJWrFgxKOHtV/wAAJ9foK2NlTcREZW/koW3oijOYuXPPPMMpk+fDkVR\nAABNTU0YPXq0s+3o0aPR0NDQ7f5GjaqAqipFPcba2mqM1M3K2+8XaI8kMXZsFSSu690r+SbOp8Lx\nHPYfz2Fx8Dz230Ccw5IvCfrqq6/imWeewV/+8pd+7aetLVqkIzLZK78IISBLMiTFQCKpY+u2NlQG\nuTRoobgKUf/xHPYfz2Fx8Dz237BYVWzZsmX4v//7P8ydOxfV1ekDqKurQ3Nzs/P77t27UVdXV8pD\nyUuSJPhlP2TVXHi9jYPWiIiozJUsvLu6unDHHXfgwQcfxMiRIz3PTZw4EeFwGNu2bYOmaVi6dCmm\nTp1aqkPpkV/xAbIZ3h2R5KAdBxERUSFK1mz+8ssvo62tDVdffbXz2AknnICDDjoIp512Gm6++WZc\ne+21AIAzzzwT++23X6kOpUd+xY9kyhxpHo5xXW8iIipvJQvvCy64ABdccEHe54877jjMmzevVG/f\nKwHFjw6YS4IyvImIqNzt8TOsAYBf9kMXZmhHGN5ERFTmGN4w+7wNGIBksPImIqKyx/BGeqIWyDrC\ncYY3ERGVN4Y3zD5vAGZ4s/ImIqIyx/CG2ecNAKrPYJ83ERGVPYY3rPu8AYRCEitvIiIqewxvpPu8\nQyEgHNMG+WiIiIi6x/BGus87GABiCQ26YQzyEREREeXH8Ea68g5YS3tH46y+iYiofDG8Afhls89b\nVc2KO57UB/NwiIiIusXwRrryllUBwGw6JyIiKlcMbwABJQAAzrKgrLyJiKicMbwBhFQzvCXVrLjj\nSVbeRERUvhjeAIKKNVJNNu/xjiVYeRMRUflieAMIWpW3IZkVd4yVNxERlTGGN4CQGgIAGJJZecdZ\neRMRURljeAMIWgPWdCQBsM+biIjKG8MbgCqrUCQFmhXe7PMmIqJyxvAGIEkSgmoAKWGFNytvIiIq\nYwxvS1AJImkkAABxTtJCRERljOFtCaoBJHQrvDlJCxERlTGGtyWkBpHQk1BkNpsTEVF5Y3hbgkoQ\nAgKBoOCtYkREVNYY3hZ7opZgCIiyz5uIiMoYw9sSVM0pUitCQCSWGuSjISIiyo/hbQlZ85sHQwJJ\nzUAixaZzIiIqTwxvi115B4IGAFbfRERUvhjeFrvP2+c3wzvM8CYiojLF8LbYzeYqw5uIiMocw9ti\nV96yz+zrZngTEVG5YnhbglblLavmbWIMbyIiKlcMb4tdeUNheBMRUXljeFtC1mhzQzJDm+FNRETl\niuFtCTK8iYhoiGB4W+w+b3tNb85vTkRE5YrhbfHJKmRJdtb0TunGIB8RERFRbgxviyRJCClBZ01v\nneFNRERliuHtElQDiGlxKLLEypuIiMoWw9slqAYR1xJQFRmaJgb7cIiIiHJieLsErWZzRQE0Vt5E\nRFSmGN4uITUAAQHVLxjeRERUthjeLj7Fb/5XNRjeRERUthjeLn7ZBwCQVYGUzj5vIiIqTwxvF5+s\nAgAU1YCm9b/ybutK4MEX16G5I9bvfREREdkY3i4+xay8FaU4fd5PvFqP9z7Zjb8u3NDvfREREdkY\n3i4+u9ncZ0ArQrN5PKl7/ktERFQMDG8Xu89bUQwYQsAw2O9NRETlh+HtYjebS4rZZM5Z1oiIqBwx\nvF2c0eayGdq8XYyIiMoRw9vF7vO2K+9i9HsTEREVG8PbxWk2tyvvItwuRkREVGwlDe/6+nqceuqp\nePzxx7OemzFjBi666CLMnj0bs2fPxu7du0t5KAWxK2/I5ujwfjebC1buRERUfGqpdhyNRnHrrbdi\nypQpebeZO3cuKisrS3UIvebPCG8OWCMionJUssrb7/dj7ty5qKurK9VbFF1Ws3mxwlsqzm6IiIiA\nElbeqqpCVbvf/U033YTt27fjmGOOwbXXXgtJGtyUs6dHFZLdbM5mbyIiKj8lC++eXHnllZg2bRpq\nampwxRVXYPHixZg1a1be7UeNqoCqKkU9htraas/vcf9IAIBqLi6Gqqpg1ja94fObp9enKv3aT7kb\nzp9toPAc9h/PYXHwPPbfQJzDQQvvs88+2/l5+vTpqK+v7za829qiRX3/2tpqNDV1eR4Lx1IAgJSW\nBAA0t4TRVBPo83ukkpq1Pz3rvYaLXOeReofnsP94DouD57H/in0O810IDMqtYl1dXbj00kuRTJoh\nuXLlShx44IGDcSge9mhzQ+KANSIiKl8lq7zXrl2L3/3ud9i+fTtUVcXixYsxY8YMTJw4Eaeddhqm\nT5+OCy64AIFAAIceemi3VfdA8St2n7dZMevs8yYiojJUsvA+/PDD8dhjj+V9/uKLL8bFF19cqrfv\nE6fyBitvIiIqX5xhzUWRFEiQYMCsvDnDGhERlSOGt4skSfApPqfy7uk+7x3hXXjsk38grsUH4vCI\niIgADOJo83Lll33QhTVKvIc+7/s+eghdyTDGVdTi9H1PGYjDIyIiYuWdKagEkDQSAAC9m8p7W2MY\nXckwACBpJAfk2IiIiACGd5bairGIGREEv/oqtic3593ulfcbnJ8lzn9KREQDiOGdYXyFORe7pGpY\nrb2af0N3i/ogT+tKRER7FoZ3hnGV6YVUVPjzbifAe8CJiGhwMLwzjK+oTf8iCquoZTabExHRAGJ4\nZxhfOc75OYEINEPLvaGn8GZ4ExHRwGF4Z6j2V+EHX/4h9I4xgCTQGm/r8TXs8iYiooHE8M5h/5pJ\nMLpGAQCaYq05t/GMV8tTebNXnIiISoHhnYOqSBApc7BaLJV7KVLhSmbeKkZERAOJ4Z2DqsiAYU4+\nl8g7AYsnvYmIiAYMwzsHVZEhDAUAkNBzh3chzeZERESlwPDOQVUkQDfDO5knvHuD4U5ERMXE8M5B\nkiQo1pot21o6cm8kvNsTERENFIZ3Hgp8AICV9TuwsyWS9TxHkhMR0WBheOdhhzdkHZ2R7pvO2SxO\nREQDieGdhyqlwzsX4bpXzBD5lw4lIiIqNoZ3HnZ4S0qe6VFd3EFORERUagzvPHyKz5yIRdaR1Lqv\nrA2w8iYiooHD8M5DlRXAUCApOpKp7KZzd7HNZnMiIhpIDO88fKp1r7esI5nqofJmszkREQ0ghnce\n5ixrKiRFQ0LLUXm7f2blTUREA4jhnYeqyAVX3nqe8GZBTkREpcDwzkOWJXN+c1lHIpljxLn7VjEO\nWCMiogHE8M7DMIQ5YE0WSGiprOe9zeY9lNicw4WIiIqI4Z2HYQhAN+c3j2mJ7rdlnzcREQ0ghnce\nuiGcZUFjqXj2Bp5bxdi5TUREA6eg8F67di2WLl0KALj77rtx8cUX4/333y/pgQ023RAQyQAAIKKH\nu92Wfd5ERDSQCgrvOXPmYL/99sP777+PNWvW4MYbb8R9991X6mMbVIYhIBIhAEBMdGU9z1vFiIho\nsBQU3oFAAPvuuy9ee+01nH/++Zg8eTJkeXi3uJuVdzfh7VmYhM3mREQ0cApK4FgshoULF+LVV1/F\niSeeiPb2dnR2dpb62AaVIQREMggASCLXet7pwK7f1pZzxDkXLCEiolIoKLyvueYaLFiwAD/96U9R\nVVWFxx57DJdcckmJD21w6a5m86Scq8873VS+uy2CpvZY9hZ2djPDiYioiNRCNvra176Gww8/HFVV\nVWhubsaUKVPw1a9+tdTHNqgMwwAMFUJToSvRrOfdlTckkQ5q9zZW5c0KnIiIiqmgyvvWW2/FwoUL\n0d7ejgsvvBCPP/44br755hIf2uD60rhqAIBIhKCrkawAzry3O3ezub1taY6RiIj2TAWF9yeffIJv\nf/vbWLhwIc455xzcc8892Lp1a6mPbVBdcsbB+N7Mg+DTqwFZR0cyo49fSieyxMqbiIgGUEHhbYfP\nG2+8gRkzZgAAkslk6Y6qDFQGfTj56AkIiBEAgMZok+d54b63WxI5VyGxA53ZTURExVRQeO+33344\n88wzEYlEcMghh2D+/Pmoqakp9bGVhZAwP+fOsDe8vROziJwBzcqbiIhKoaABa3PmzEF9fT0OOOAA\nAMDkyZNxxx13lPTAykW1MgotALZ37fY8nll557rXO+lrARSJfd5ERFRUBYV3PB7H66+/jnvvvReS\nJOGoo47C5MmTS31sZWGkbzSAXM3mmaPNvQm9qf1ztO31OvwVYyBaTy71YRIR0R6koGbzG2+8EeFw\nGBdeeCHOP/98NDc344Ybbij1sZWFmmAVhACi1uIkH3zaiBfe3gJkNJvruje869s2AQCUmhb2eRMR\nUVEVVHk3Nzfjrrvucn4/5ZRTMHv27JIdVDmpCKpAVIJm6ACAB55fCwA4cLLrukcS6EqGcdt7D+Oc\nyf+GQ8cchNZ4KwBApHzs8yYioqIqeHrUWCw9g1g0GkUi0f0a18NFZVAFhAxN1z2Pp9y/S8DqjlXY\nEdmFBz5+GADQEm8DAIhkiH3eRERUVAVV3hdccAHOOOMMHH744QCAdevW4aqrrirpgZWLiqAPEBI0\n4Q3vpK65fhMQGQndaoe3prLyJiKioioovM877zxMnToV69atgyRJuPHGG/HYY4+V+tjKgll5S9AN\n74xqKS0d3lLGgDUhhFN5QzYY3kREVFQFhTcA7LXXXthrr72c31evXl2SAyo3duWti+6azYUnoBN6\n0pk+VZJ1DlgjIqKi6vOi3HtKNVkZVCGEbC5U4pIy3GEunAFtABDX4+mnFH2POVdERDQw+hzekiQV\n8zjKVkVQBSBBhze8tYzKO2GkB/DFtHR4S7LOAWtERFRU3Tabn3TSSTlDWgiBtra2kh1UOamw+ryF\n8PZdp3QNAdd2yTzhzT5vIiIqtm7D+4knnhio4yhbiixDEjIMpKC5J2KR3TOsGXkrb7DPm4iIiqzb\n8J4wYcJAHUdZkyUJAgZSWrqpXJJdt4pJQMod3qmoazsDBpjeRERUPH3u8y5EfX09Tj31VDz++ONZ\nzy1fvhznnXceLrjgAjzwwAOlPIx+kyADEEhprn5v1R3eAkmRDu+2RIfn9ULSQEREVCwlC+9oNIpb\nb70VU6ZMyfn8nDlzcP/99+PJJ5/EO++8g88++6xUh9JviiRDSAaSrvD2VN4QSBnp9c2d8DbM0ysk\n721mRERE/VGy8Pb7/Zg7dy7q6uqynmtoaEBNTQ322msvyLKMk046CStWrCjVofSbLCkABOJJVwgr\n3klaUsIV3vF28wctCAAQYOVNRETFU7LwVlUVwWAw53NNTU0YPXq08/vo0aPR1NSUc9tyoMgyJFmg\nI5JuGpdUb+Wtwd1sboV3yhwf+2w3AAAgAElEQVSPzsqbiIiKqeAZ1gbbqFEVUFWlqPusra0uaDtV\nMU+TUFzXOlblLTQVkj/puQu8PWk1m1vhDVkv+L2GouH82QYKz2H/8RwWB89j/w3EORyU8K6rq0Nz\nc7Pz++7du3M2r7u1tUW7fb63amur0dTUVdC2spABCdi2M31vu2SHt+5zqvB9qvZGQ3gHuhJh87lU\nABIAQ9IKfq+hpjfnkXLjOew/nsPi4Hnsv2Kfw3wXAiUdbZ7PxIkTEQ6HsW3bNmiahqVLl2Lq1KmD\ncSgFUWTzNHVE0/3akDXz/m3NvP6pwlhMm+AdnCecZnP2eRMRUfGUrPJeu3Ytfve732H79u1QVRWL\nFy/GjBkzMHHiRJx22mm4+eabce211wIAzjzzTOy3336lOpR+UxUF0IHOqGvaU1UDdBWQzHu4fQhA\nldOn0y/7ENet5nb2eRMRURGVLLwPP/zwbpcNPe644zBv3rxSvX1RqbIV3rH0oDQoGoSuOn3fighA\nkdN98j7Fh6iumE0bDG8iIiqiQWk2H2pUK5S7oq7R5opZeUtOePuhSunwViU1fZ+3zGZzIiIqHoZ3\nAXyKGcrhuN3nLZzK2x6spghvs7kqqxCaFeYyK28iIioehncB/NatYl0xK7xlA5IkzD5v2A/5PM3m\nqqxA6NbvisaVxYiIqGgY3gXwWfeX68KqoJ3bxNzh7Tebyi2qpMIwrN9lnUuTEBFR0TC8C1Dh91k/\nmRHszGvuCm9J+Jy+cQBmFW5V3pJVeXdGk7j/2dXY1hgekOMmIqLhieFdAL/PCm/JmkdNsSpwwzXj\nm65mNJur6cpcMdf0/ufyrVi1sRn3Pbt6AI6aiIiGK4Z3ARTJOk2SgCSlK293szkMNavZ3A53STYr\nb3s98GSKA9iIiKjvGN4FUOxbwCSByqAPvoDVg+2qvCXd22yuuprNoegw2OlNRERFMmQWJhlMslV5\nS5JAZciHsGrAACB0BYmNR0EZ2QRZqvbcKqZIKgAZQpedypuIiKgYWHkXIN1sbuDEr4yHL2D1fRsq\njLbxSG35CoRhB7b9GsXZxu7zdkjSwBw4ERENSwzvAshWEP/7qZNx5tcmweczk9i5jxuArouMZnPV\n2UbKvM+bVTgREfUDw7sAduU9fmwIkiRB8dmVtyu8hchoNndV3jL7vImIqHgY3gWQrSVBDWGGtqxa\no8Vdo811XXjmNreb0IWuAIoGwzDSO2SzORER9QPDuwB2Fa1b4S1Z93m7m80NQzgD2wCkg1xXIUlA\nyuDiJEREVBwcbV4AO5S/6NyGz9o3A0rKfMJwVd6GAclVUctOs7n537iWXguceuetj3dgQm0lDti7\nZrAPhYioLDC8C2D3eS/e+joAQLZWGfMMWDPSk7CYr/Fuc+fqe3AELhqQ4x1OYgkNjy7cAAD4yy9m\nDPLREBGVBzabF0B29WUDgJCyJ2nRDYHHXql3frfDW1LNKj2hJyBghntnJIkHnlsDg6POe6TpRs8b\nERHtYRjeBVAk72kSMMy7vQxvn/fGhnbXa8zntN2T0tsgXZl/UN+Enc2REh0xERENZwzvAmSGNwAr\nuNN93PGk7unztkebG51jobWMN3+Gd05znfeP9YhniIgoG8O7ALKsZD+oe4cLRGIpz+8KXK+xKnQD\n3hHnDO+esWeBiCgbw7sAco7KWxgZ/eAwB1c5r5HdK45Z94lL3srbYHj3iOdoePvX+t247I6l2N0a\nHexDIRpSGN4FyN9s7hV2Vd+K69TaQS/YbN5rDO/h7c8vfQLdEFi2eudgHwrRkMLwLkDmaHMATjXt\n5g7j5vZk+glhbtssbYLkT1cYKY6k7hFH5A9v/PMS9Q3DuwC5Ku9RVaFuX7P4vW3pX6yg362uQ/Co\nt5yHUymGd08Y3kRE2RjeBcjV5z1+VBWqQj4AQCjQw1w3OZrYAVbehWCzORFRNoZ3AZQczeaqrDrL\nfFZX+Lp9vcjRxA4AyZSe83FKY3jvGbhWD1HvMLwLIOf4ZnEv/1kdyg7vQyeNxrUXHoWvHTouu/KW\nNQACb3e8jHe2vwcAWPDOFsxd8ElRj3s4YHYTEWVjeBcgksq+jcW9/Gd1hT/r+ZHVfhy272jzOeE9\nzZI/DskfxxfJT/HEp88CAJ5ftgUr1u0q8pEPnLWbW7BibfGPn5U3EVE2hncBJo3YBwDw5VGTncfM\nZnPz5xGV6fAWmlmFj6kYCQCQ5exmcykQg+RPrzJmrxMOwGmKz2f+ss34+LPmPnyK0rrrHx9j7kvF\nbznggDUiomwM7wJU+6vwwIw7cOa+pzqPqa5Z1/yq7Axei6+ZisTGo3DUhP0BABUBNavZXArEIAVi\nzu+NkXQYdxdWndEkXnznc9z7zOr+faAS6unio7fKObxffGeLs+IZEdFAYnj3guIKbFVWPfNu71NX\nBQCoCYzAdbPOwKTx1QCAiqAv655wSUlB8qfD+4GP/+KsEa7p3YR3JJn3uXKR1Io7gr6cm83nL9uC\ntz7eMdiHMaSV8bUZUVljePeCu59blVQ4y2ZIwMGTRgEAamtCzs8AUBlSs/q8IQlP5d2aaIW692YA\ngN7N7WPhaCrvc+Wi2CPoyzm8iYgGC8O7F7Iqbye7JZxxwpfwzan74rJvHOp5TWXQlzUPOmTDCe+v\n7zPd3F9tAyBrWZV3Uk9i4ZZX0Z7oQGe0/CvvRLHD23U6/ufvH6KxPZZ/40HCC4y+4y1iRH3D8O4F\n9/3e7j5vSQJURcbZ0/ZH7UjvzGuVOZrNIRmQ/HEoIoBzDzwLB4a+AknVIPnj0DIq73/Uv4CXtryC\nFzctQke4/MM7WeRZ49x93vUN7Zj32sai7r8YONlO37HZnKhvGN69oHbT551PZUjN7vOWDEi+JFQj\nCAAwdOt5SUBzVXHtiQ6s2LnS+b0jo897xY6VWLBpUS8/RWkVu/IWGVVtOS7mknnBRURUagzvXvBU\n3pKCQtK7MuiDENnN5lBSkI0AAEDT0o+7+7zvWzU3/RJJdgashQLm/h7f8DQWbX0dulE+M7UVu887\nM6zLsYk6VeRBekREPWF490L2aHMzSLrrtzNvFcucpCUBSQIk3QzvlDUOTZIM6Fafd2ckieZYi/Oa\nqBZzKu+qjBndolr59AMnSthsnuv3cqAxvPuNfd9EvcPw7gXPaHO5h8VILLIsZd/n7TMnaJE0c3IX\n3S5WJQOaYQbBtX98C7rQceDIAwAAsVQMHZEEAMCvKp77qbuS4d5/mCJyH0vxR5tn/l4e4a27Dox9\n3n0nCup8IqJMDO9eUFyBrcqq606xHsoG4X1e8pshbM/GJgzreUkgkdTxxqrt0GWzyq5UK+BX/Ihq\nMcQTZjDqhkBcT8/QFklF+vyZisHdtF30Pu/MyrtMwlvT0sfR3b35ROXglZUNWL+1bbAPg4qI4d0L\n7nW9PQPWemzyywhvnxnMwqq8Dd16Xjbw7Fub8bfFn0JSzI5wvxxEhRpCTIs5FZ4hBLqS6cAO55h7\nfSC5w9tdeacMDXd/+Ces2LEy18sKkt1s3uddFZW72uaANSpniZSOp17biN8/uWqwD4WKiOHdC1kD\n1iw9ZXfdqFDOx42kWXk74S0Z2N5kNoFLqtkRHrDCO6rFnYFRhiEQTrnDe5Arb1d4ufu8t3Y24LP2\nLXh8w9N933eZjjZ3BzYHrFE5K5fWKiouhncvSK5RNYprkpae3Pz94zC+60QkPj3G83gqYTbD233e\nkiTgXApY06X65QBCaghxLY6UZm5oCOFpKh/sZnMtT+WtGVquzXsl84unXAasuQepsc+7H6w/Z+bY\nBiLqHsO7j3yyAvf0qN0J+lVMCnwZRkcthKv/OxaVYRgCuqvytkmqGXw+KYAKXxACAklh9pUbRmaz\n+WBX3q4+by0d3rmWUu2tzLDOvO97sLgvWDjavP/K5aJsOOK5HZ4Y3n2UOT1qTxTFOtVGeluR8iMS\nT0HX0n3e6RdYlbcUQIVaYT1vPmYIIJxKjzAPJ7NDMvMfbCKl4911u5zqvZjczebJZPrnLtcxLnrv\niz7tO/N7Rx/gLyJNN7BkZQOice+88u7AZp93/7Fpt3TKpauJiovh3UfmwiSmQu5R9dnh7VqkROgq\nWjsTcPLUGm0OpPu8VRFASA1ab2pW45l93pnN5l/s7sIPf7cUb3603XnsuTc346EFn2D+si0FfT63\ndVta8doH2/I+7xlt7ro4CLtuYVv4Xu/fF8jRbD7AX0TPv7UZT762EX9fUu953N1Uzmbz/mN4l065\ntFZRcTG8+6jQ+7xtimIlvHuFMV3BLY+uRDhid3obTsVsh7cCPypUc8CbZFXjhiE8TdI7Irs8s6wt\nX7sLAPDU6585jzU0dgEANu3o7NVxA8Cd8z7C35fU521+y9fn7b7/3JD7tiJa1gxrA/w9tHFbBwCg\ntTPheZwD1orD/nOyabd0mN3DE8O7j1RZ6dWiCnblLQz7vxKc028FuiS5m83NKlsRfgTUgPVYesBa\nXDPv8z669itoT3Rgbct656V2FSP7EqhvMwPcp5qj4/vTbJ6vOvKMNk+6wtvVIiAUb/gV/J5Z93kP\nbFC2dZnHPbI64Hnc22wuEI2nsO7z1gE9tuGEAVM6bNUYnhjefaRIhU2PalNVO6itjQ1X5S7Sk7TY\n7CpbNgLpJnopfatYzArv0yedAgB4a9sK57VOv/A+a3DvqofwcdNa+K33T/ajSszXt5tvkhZ35d3X\n8M5s8hvoUcntYfO4R1T4PY+nXIP0NM3AnfM+xp1PfYT6hvYBPb6hzv6nM9AXZW5CCCz9cBt2tQ7u\nfAmlwlaN4Ynh3UeqrOLkoycAAA760qgCtvc2m8vCHd7Wn8E1YE3yJyAMCbLwwWc10UtyepKWmBaD\nLBQ89sIuHDhyf2xo24hdkd3m7uzAC5lBMn/Ty/D5zPdI9WPu8XwDX9yjzd39v+5BdYbct+VMM9/S\nEAKabmTNvFYq9mfOfD8tY5KWLTvN7oimMlxvvJw5zeaD2POweWcnHnulHr+a++7gHUQJsfIenhje\nfeSTFXzntC/jjh9NwWH7ju5xe6fytprNZeSqvN3hHYNIhqDrrv512Wo2N4CYFofQfdi8vRMnjDfv\nH/+0bZP5vN1vnjJHqTdGmyGpZgWZ7EezuZ5jGlAhhGeeb/e0oRHXKHhD6Wt4Zw9Yu/z3b2DO3z7o\n0/56w90FkDkoTcszYI0LbPRNb6vD9zc04sEX1xWlqozEzC6q4VqgsvIenhjefaTKKmRJwtiRuWdP\ny9o+Y7S5e7S63Q9uN5srqg7Jn4RIhKDpBnyKtYqY5K6844BuTtEaUioBAM+/vRHN7bF0FaOkB4nF\nVXOFMvfgqriW7hMvRGbl3RZvx/ee+ylWtb0PqGY4u0MtYbgCW+1bs3n2gDXzd7vSLaXWrvT88cmM\nFgv3eXT/LGekdyyhIZbo/2Q1A03TDWzd1TVg79fb6vCP89fivU92Y3cBTd2vf7itV5/FMATunPeR\n526NoYyV9/DE8O4j91SphVCt0eb2JC2q5FrW0x6wZjWLT5xo7lskQkhpRlazOWCGt6GZj8swt4+m\n4nhx+efpK20lHRqblXcANenp835iwzO4d9VD+KhpbUGfQc+oPjd1fI6ElsCyliUIHvkG4Is74W0I\nA5qhQYHVV9zHyjuzz3sgFwGx108Huq+8NU/l7Q3vK+5+C1fc/VaJjrB0HlrwCW55dOWA9eH3tTp0\n5k/Io7E9hsdfqcctj+afXz+ztaSxPYZ1W1rx10Wf9umYyg2ze3gqaXjffvvtuOCCC3DhhRdi9erV\nnudmzJiBiy66CLNnz8bs2bOxe/fuUh5K0Vz+le/hrP1metb2LoRdeUtWda1KKn71PWu61IxmcyVo\n9puKRAU03Ug3m9vN6rIBXegwUnZ4p5/fvKPTudIWcgp1FWMBAElEoY773FMl2iPUN1rN7T3pbrIH\nSTGgjGx0giypm8EXRJV1Avp2q1jSSHpaEIq95Gh3uqLp982cRU3zDFhzDTTsY7N5S6wVN7xzOza0\nbuzbDors/Q2NAFBQZVsMfa0OMy8oMxXy/0vmn2y4dX2w8h6eShbe//rXv7B161bMmzcPt912G267\n7basbebOnYvHHnsMjz32GMaNG1eqQymqI2sPxxn7fb3Xr3Oaza3wliDjgL1roMiS0w/ujDb3m1+Y\nduWdsjPErrytMBO6VZFb64VLio4dzRGzipEMQNYxOjAKFx30LfP5jACt9JnN7YVOr6plfAnYI95t\nysgmZxR2QrdmiDMqrffuW+W9UnseoWNegz20qbezRdW3fYZH1j2BVB/mWe+KuirvjLECqTxzm/e1\ngnz1izfRlmjH3DWPFbS9EAJPLKnHui2lvT2tKuTreaMi6Gu+9HSPfUE5LHX765DHPu/hqWThvWLF\nCpx66qkAgAMOOAAdHR0Ih8M9vGr4UuzR5s7tZVbftyJD2KPN7T5t1aq8k0GkdAML3ramFrUGrNnL\nhcIKbwjF83wkrjkBH1KDOGj0gZ7nbVU+c0BbOJk7vNvi7Xh03ZOQrIuJzCrHvtf8lJFnw4hXQK5u\ncyrUlNXfrYgghC47y6D2VhhmOMk1zX16/b2rHsL7uz/CxwV2Dbh1uirvzJDwDNJzN6FrfWz+tVpy\nNFHYRcb2pghe/WAb7pz3UZ/erzvukfWZF2yl4q4OOyNJrN3ckndbz/H11I3ShzJ6uGXdnlh5G0Lg\nd3//EP9c8flgH0rJ9G6asF5obm7GYYcd5vw+evRoNDU1oaqqynnspptuwvbt23HMMcfg2muvzeov\ndBs1qgKq2rum6p7U1lYXdX/dGdlsNT/azeaKitraavhUGYmUt887FJKBlFlZ+/wqWtpTwCjXJC5W\neAvdrIpGjrA+hxXOmiGchU1GVY/AXnXWrWzW/u3PXREIAl1AzIjmPBfLPnkbK3evQuAIGfH3T0f1\niJBnu+QXZrhVh6ogkgHIwSg0w0BtbTVi7eaAMkX2QST9gJrM+R7vbVuFUcEafHns/p7Ho/EUKoLp\nqk8Zux1GR61nm978/Xyh3v+9U64vPUOSPK/3B9LHJrv6XYMVfmc7dytBT+8dCJj/FDVDK+g4O+Lp\nC7Fi/3/c1pluUQm5Pk8pqT7FeZ9fPLQEja1R3Hftydhv75qsbSOx9EVVVXXQeV2u44y7rrnyfY6a\n1phnm5he+N9tKGgKpy+cC/k8w+Ezh6NJfNrQjk8b2nHJN78y4O8/IP9mSv4Olsz7ZK+88kpMmzYN\nNTU1uOKKK7B48WLMmjUr7+vb2orb91ZbW42mpoEbTdvVZX1BWAEqdKCpqQuyLOWYpMX6YhYyOrvi\nUCUFCddrneZva8BaW6s5ktsO/65Iup9Y0hR0tVnPWzO0NTZ2QpIkdMbMintXuAlf7GyCX/Z5+vI7\nwzFnv1IwjJaWCJpC5nsmkjpefOdTqOOARFQ4rQApI4mmpi7s6jAHOukpCdD8kIKRrPNtCAN3vvMQ\nAOCBGXc4j2/Y2oY7nlyF807e32yokAClphkpyfBML9ubv19LR5dn+22NYUgSMKG2yrPdZ9s7cO/T\nH+Pq849EY0u6RSIWT3le3+EKuIireb2tPepsF0+mq+jujvWtxmVY9Nkbzu/23ydTMqXj3U9247iD\n69DWnv73UKz/j1es24VdLVEctl/61sfWtuiA/DuJu85vo9XPvnFLC6p82Y2Dja576Ztawmiq9uf9\n99zSkm7ty/c5OjLOZVNzz68ZSlpb0/8f9/R5Bvp7sVQiroWEBvrzFPsc5rsQKFmzeV1dHZqb002d\njY2NqK1NV05nn302xowZA1VVMX36dNTX1+fazbDh3ELk6vMGkNHnbQ1YU+1FjmVougG/Yo3Ylg0E\nfIrTbG73ecPwNpvHk5qzTYUagk/2eZ5/7q3NeGfNTqfPOqEncd1bv8b1Cx/CZ9s7nGN2z58uBSOe\npuKuWNJpAZAMn3MsQtagG4bTbA5dgdB8kBQdCc3bdJ7Qc98+ttIaLLVw5RanA1JSNchVfR/5HE15\nJ0/59V/+hRsf/lfWds8s/QyRuIZnlm5yms1HVPg8zea6YeCL3el/nO4+b/e98O6R/d01Xc5bu8Dz\ne0TLfaG6YPnneHThBjz56saSNO3OXfAJFiz/HM0d6XM1UPO25+qXzddk7668e1qOtZAxEpnvM9xW\n4ervx+mIJPHBp03FOZgBMtz+hrmULLynTp2KxYsXAwDWrVuHuro6p8m8q6sLl156KZJJ88t85cqV\nOPDAA0t1KGUhff+vNe847D5vKV1NWuEtK1Z1bshIaQYCavo+74BPTt8CZjWb6xrM6t0K51hCd6rz\nkBqCIiuQhAzJev6fK7bi4X+uzxpwFg5twd3/SPehulcFU0Y1oiWRHhxlGMK5QJCFz6m8JUWDpgkk\nrNHmwlAgUubFR1vcezUaTaXf372wiv1FLqvmY0I3L07kEd5+UPMiQcsK5lw6k7mvhA1hYHc0/cVk\nT6aj6Qa6oklUhXwI+BVPiK1YuxtrXQPFtDyD19yz2el5phAzRPbjLbHcg9B2tZih/vmu0t7j3tKR\n/ruUMrzdrXG5Lm7yjST3hHcPo80LGayV+d7D7YvffQ76MjPh//z9Qzzw/JohNfVvrgmlhpuShfdX\nv/pVHHbYYbjwwgsxZ84c3HTTTXjuueewZMkSVFdXY/r06c5tZKNHj+62yXw4kK0Ba3a/tV15T6yt\nAmA1nctWVW6FNwwFKV0gqKbv8w74FSek7cldNN2AJBSn2Tye0JyAt5cTlaB41wuHQEJPoNJeKxyA\n0FTPVbq78lZrt+PvDQ8imdJR39BuTlpih7er8oaiIaUbSOr2iHgZ0Mzwrm/x3pIW1dKh61772/4y\ntS8OjC6zGVcOeQc8aprAw2sfw8+W3YRIKuoJccMQePHtLc5kOJ3J3IG3cMur+M27v3fudbfvCtB0\nga5oCtUVPvhUb3hv3tHh2UfKM2DNtba5PUJd0pHQcg9Ei2vZrQ/5Rv/b/w/phijpl1OLq0uglMud\nunMkZ3jnCdGwK7x7Or5CgjgrvHvY519eXo///r/lPe63XLg/X19Gntu3Cw6lqX/zXSwPJyXt877u\nuus8vx988MHOzxdffDEuvvjiUr59WXEqbzugrdHml5xxMPbfewdeiSsw7AFp9n+FDE0zYOj23Oe6\n2Wwup8MdsKoj4Qp1pG/NspcTlYXqHW2uaBAQ2K9mknO/t4hXpudgR+4Q+cvL6/Gv9Y046/9NgqRo\nELoC3YCn8tZ1A0nDWr5UV2AkzGOYt+kZHD/hSAStVdJirvDuSHRiZMAcnGR/v8jWoDsRr4DQFUhB\n7/GkdANrms1j//mymwEAPzjsIhwz7ij8a/1uzH97C0LHCEABOhPp4HdXa+/sMCfvWLV7DY6qPdwJ\n70RKRziWwoSxlYgndU94B/2utdxlAy0j/gUlPgJSMIIW3Q/AHHyXTBmAZCB45FuYv6kT3z3sW1nn\nM7P1AzAHreVi37FgGKLHirM/8lXeiZSOrmgSY2sKm1WwJ+4gyZWx+YI3Ek+fn54uYgoZaZ35Pj0F\n/turd1rbGVDk8p/nyhPehkAP89rkNZRaJIbSsfZV+f+fN0yMqQlaP3mbzasr/Pi3KftClc1wkkc2\nIpwKQ7Kq8ZRuIJGy/keUzD5vJ9ytyjulGea93q5wloLm1fLY0Bjzd6E4zeZAetWyoBLElV/5sfmg\nrHtmrAqnIhgd9C668q/1Zn/0xoYOc1CcrkI3BISRWXlbzea6DL3xSxBJM7Cjrv5cd3gv3rrUqdad\nudntixFdhYhXWp/JfZtQdoDtipjHZ37BC2cZ1Q5X5e2e6tReS317s9msbs+E12atJmZW3rInxGLW\nQLTbLjsBoTFtSFR/Dv8Bq+GbsAmrxItOU3hS0yH545D8CWzsyD0NbVw3gzKoBHDKPicCQM570qOp\nmPP31Q0BrciVhbs5tdm1drm7JeGOJ1bh539a4Zl5rj/0HirCfBcoUdd0sz1V3oWFd+ZtgIV98Wd2\nKQgh8PFnzWU3Ha773PYn1IbSLWdsNqeiGVUdwG2XneBUywq8k18okgJJ1RD48ofYEdllzaomIaUZ\nSKYEhCFDkg34fa6QFq5lPo2McA5EASFhTMhscpZyVN4AsGZjJ3738GdQjRAgG051J4RAOBVBlTWR\nS6akpluVt2p++en2RDEaNF044W1oMiBk6O3mYEU7oAHvILKPm9bipS3mGIms6V11FUas0hz17k+/\nRtMMyJL3f2FP8Lmmh+1IdDrv51621J4AJ2m9zl533V6UpLrSD58qw3AtwBK3ngv61Zwz7dmfMakZ\nkHxmOLfEWz2f3WZX3idNnIqJVXtnfwaYs9XNee9ObAq+CkBYK6sV98vJHUSeytsVjvZ88u5m6/5w\nh0GuUMn3GfU83RQ5t+1L5a27jyv//jOPb/naXbj3mdX466INPb7nQHJ/hP5c8w2lanYoXWj0FcN7\nAO01phIHpk6D3joOU+r+n+e5zBHGftkHVTFHmyeSulllywZURXaazYWr2VwYCiRfylkARA5GIeuh\n9LzoIqMyt5qk7XlzzIsD3ak8E3oSmqGhUq1AcrN5n6Q9hzoAJDTdDEddhaYJT5/3X15ej664GZS6\nZq+mZr7WDnXAW3kDwEeNa8xNnT5vb+UNAFIo3XSe0DSnYjyq1jzGlGEHp56ezAaAgMD7uz92nks/\nYU9ba430z2hTHFFhhjeQDji7sgr6FShq9rdh0khCNwxsbwxD8iec998VzZ4C2J7oJqQGnb9VKiPk\nP2hcjY5kJ8LyLsjVbVafd3Er77hrBTXPimnWZ3avsFasL3F3tZ85h735Prk/o2dq2j40my98dytW\nb2rJu02+VfIyZVbe9iDGTdtLv2hOb3gGBvbjNoVi/z9XSkPpQqOvGN4D7IpZU3H1cT/AtMO+1O12\nqqzCp5qVdyJlhndNtYq6kaH0wDO72Vw3zIFhAEJfXQqoCUj+BJRUFYQQ+OPzaxCNCUiygNPsbM8X\n7p5iVU734dn93RVqJTsCVIYAACAASURBVPTmCdA7R8OADsBuEk5Bks3Qjic1T5/3Z9s6sGqTGVS6\nZlXyVmVu94UDQDSjv7cl3obGaHO6/9OpvBWIlNns7p5mNZyMQkDgyNrD8c39Z5rnwqpaw7GU83q9\nrQ5CAG82LEdjW8QTRPY99kKyBt9l3F49wmo2B9Jf1vGEBglAwK84I+LdknoKTy/dhKde/wzwpZug\nd4R3ZW1rV95BNQjVuqVPM7zhvXLXh87PytjtVp934V9OWzq+wJ/XPo5wMpJ3tHE8zxzg9mduaDSv\n8tQJG/HytpcKfu/u9NRsnm+kuztce2w2z9hvNK7h6Tc24Z6nP05v002fd+b+3ecvc8rcZmtA11in\ni6w86D20cBS8nyE09Vyxu5XKEcN7gPl9Cg6eNKrb2eQAwCer8Cmy1WyuQ5FVCDkJQ04BkvWl4fR5\n604VDgDKSPP+eilZiVhCw/ufNmXdCy75zdCwQ1FYfeZ25b0zYgbNCL81QYDzeiu8EXFeH4lrnsob\nSFfYuqZ4Xp9wVd6fN5mVynXH/Bhn7GtOpdueaE9/Qcqu+9l17/EDQJc1rWuVr8IJPrvyjsRS6dHq\nsSroLXthV2wXfvXss3jh7S3pE23tLwnzizfzy7raVXl/vH0zVu/YjFhSRzCgQJYkyEqu8E5imTWo\nya68ge7DO6QE0pV3RrN5S7wN1b4qyEKFXNmRNWBtV2Q33tuZf33zf9TPx6rG1bjx5b/i9sdzbxfP\n009rn4+GRnNMgG/CJqxu/xDhWApPLKnvVxO6O0dyZUrmMqw276IwvWs2D8ey++u7azbPvIBwH1Pm\nc01Wd8OoEYFuj2mgGT3cklfwfoZQNbsn9HkP2Axr1DuqYt5fHEtqSGoGAlAQTnXhXflvgGwu4iKE\nq9lcl5wFFeRqs0lQSoUQtkfmWkHv+9IGpD4/DFLAXrnMHDls6GZzvH070spdq8ztIxMAtDmVM2Qd\nMFTEpU4oMEeoR42Uq/K2mrqtCwwtZVW2OZrN12zdBXUsMMJf5dzSFtPi6S8J2T52NT0JnSss7TnZ\nK32V8Cne4AvHNE+fubbjAKhjd0Ie0YpVG9OTB9n3w8eMMO54/37ElNEA9nKe19ROrK94BsrY/fBk\nwyLz/RJnOyPOJVflbUSrIFeEkdCTqAgoiCU0p88bAHZEssM77qq884V3OBnBmNAoSMlKdIR2Q0fK\nEzi3vncnAKAmMAL71UxCwJ7Uxz4uawBdomIbNm34ctYxAN5mczc7HKMJLf33ADD3pbVYs6kNmiHw\nvZkH5XxtT/L1ecuSBEOIvCuCefq8ezlgzT1ffa73BrxVW+b+3bPmuS/0hBDOQL5yC7nM0eZ9NZQC\nkc3mNGh8soqAT0VXxPyySfc3Cydw7C/7ZMqAcC2bKVeYVZIwZIStLys7PNW6bZBrmiFb4W3Ezfu8\nDatvWlENGMLA6uZPEMIIvPCKNWGIYc+/bo149pnNqCJe4am81boGyKN2Oc3Qeko2mxFzhLd7xLs7\nvJ2Kxj2TnD0Lnavytu9Dr/RVOLPI2f3F4VjK2b/QVAgtPdGNh7WNhhS2djag0f+x5+k2YxeSUgT+\n/dMLm8STOkKBdDcBACTWHwe9dbzzGYP281blXaFU5q68dbvPO2Teiw/vrWIpQ0Ncj6PaV4VqqRaS\nBBiBzpxNyvd/NBd3rLwPQgi8/uE2ayY2gdZ4m3ksqubchZApka/Z3AooTRee127aaf5/YfSjedId\nJO4+b7v1J6nluaDoplk7U+aXeFeOkfLdNptrmeGt53zOfftavhYDIQSefXOTZxbDgeAZbd6Ppu+B\nWqSmGPaE+7wZ3mXi9Emn4OBRBzr3OvtkFUG/4vzDE3L6S8eerGREyAy8aELzNM/KlWZ4G7qcbtbU\nXaOiJQEpEDOraWsCFfteclk2oBk6UkYKWjQEZ35Su9ncqnxl64vciFeiPZxIr3AGwDfhMwgrZFMp\nCaOqA1AlMzyT1rSpumFO8iKEOUNb0BXedsUl5HSfd2azPwBENKvyViucCxk7+CKxFKCmK3e4lk1N\nnwcjPSFOPkp2c3IslUDQnx5dDwBC8zvvsaO1AxV2ePviEJoPtYFx6Eh2eia+AbwD1pZ+YIb79tb0\ngCd7lrsqfyWqYN72JwKdeQcP7Yo2YmP7Jjz+Sj2WvN+AjmSXZzIcuTJ7MFVcS2Bly/KsVeeAdEBp\nugHZdZ99NGVdlAT7vmSonmcglWKHd54Q9FbehQ9YE0KgM9q7ZvPsyts1sM89IY/r4iffRcfnu7rw\nzxVbcftj+bs4SsGd17kGBvbEPb/AUFGs1oZyxvAuE//fAWfgJ0dfZt0iBqiy2WxuS0npL2A7qGsq\nrfCOa5B82TN1CUN2+vjcfeKSrEMKRK0mc8nZFgBk1XACUNcl1768fd72hCkiXoH2cDIdrjCb0u3K\nW+gKVEXGiKDZPG/fLhWOpszPofmh6cKpvONaHAnNACCQ8rUDwgxGu9nefTucHUqVvgrzVjtIzoC4\ncCzlDG4TKX+Oyl3Af9BKz2fPJQVr/vdPj4HeYYYnanYh4LcvqlyD6qxz8MTrG8zKXElBCkZhRKsw\n2mfeKpdZfcdc4b16o1khN7anJ5SxZ56r9lVBMsygFJLebRW0rvlT5+ftHebAwUlV5gBJKZQ9Teyz\nG1/EB13L4Nvn06zn7PBKaYZ5+6HFvmjpzz3NIk+zuT1oMpmnP7uvfd66Yc6alykz4LuvvHM3m7tb\nLqJGFz5sXJ31Pok8XROl1t8Ba3Z4r9ncUvKpeYuluwuw7vx10QZnEp5yx/AuM4p137JPVhH0eatl\nN2HIGGmFdyyhQW8dl7Uvs/K2vmxEOoilQBSSqkEkXTNluSpbu8/VHinuft4OTykQMydesSZnCfpV\nnDzCmkFMNiAk3ZqaVIJPlVFTYb5XJGmGVWs4BikQgxGvQDJleJrNkykdck0z9EA7KhP7mHO4Z1T+\nABC1Ku8qfyUkSYJPVhFJxLHgnS1WeNvN5n4AMoQhpcPfl4AywgxLvWW8s09ZT48U3m+vaucWPpEM\nQsTMufn9B6xG20irerIH1bmqe8jmjGxyVbvZzN01GmN85t9nc8fnnr+RHd6bGiLOHPbuPm+7X7/K\nX+lZtz39hZT9ZRxOpPvZd0fN/v0vjzBnN5QrO7NGnO+KmhPb5Ap2O7xSuuFtclfSa8c/99YmvPbB\ntqzX9iTfaPN05e0Nu2ff3ISXln+OsNTsdH/0ps9b13NX3lpGuOVbqx0AYnmazd2tBLvHvYSH1z6O\n19au97z2b1/8H/wHZy+GUyhDCM/FQ8Gv62cVav89GhrD+M2j72PTjg7c98xqRON9v3ArplxjI/py\nwZJM6Xjzox34y8vre964DDC8y4w9baoqKZ7KO4sho6bKbPKOJjSkPj8MX459A0YsPamK0CWn8naW\nEQUgWc3u9qxn9v4As9nZvlXJXXlnjVZXXCPMAVQEVewd2MfaRoMOzemHVhUZFX4zFCNWsOzobIYk\nCYh4BZKajpCSEd6VZr9gZXQ/81hzNJt3psyFEsYEzYrYp/iwszWM55dtMf/B+lyVN2BeaNjH75rn\nXa4I43jpAnNbpB8/cOJIp5lbaH7rIsDU5WswH5fTg+Ls1gF17834rKUBcrV5cWCER2Kczzw3G9q8\nM63t6mqF0GX88dkNCPnM/bv7vO1b9qp9Va7WAyNdfarZYbR5d5vzc7u1GEyNOhpGIgg51JX1ZaZI\n9oWZAXWfDVDqvjB/l1zN5prhad2RrM8djafw0vKt+PuSeuyM7Ma8T5/POV97Lkae+7ztqYTXbmnF\nR5+lBxf+c8VWvLBuGT6vfhnq3uY8+T32eXtmFzN6rLx1XXQ72txTeWu5K2+7p+mJ1z51LpQMYaAj\n1QZlRO5FZwpx3zOr8V93vdVta4emG57lMM337t993plTwP7+yVX46LNmvPnR9l7vq9iefXMTfnTn\nm9jZ4p06ubtBh24bt7Wjrcv8/zVfyLd0xEs6HXFfMbzLjF15GxCe8E6PJbcYCmoqrfCOpwChoEau\n9TRf64bkDFjzfPFat4l5mrqtn9uqV6fvv3Y1J2eFp6x7Xl8RUOFTFXMOckWH4QlvCRX/P3vfGW9H\nVa/9TN/19H5OzknvIR0SEjpEulIFiShYLyI2BEQR9PpD5aJX5d5XQbHAtYAIypULWABpIXRIg5De\nc0pO3XXKej+sMmv2npOQkJAE5vlAOHvKXrNm9jzr356/wfTMWax0xyDt5EUKSdiOhxjTYM+5eRSY\nJjgdgxEYq+w2H3D7YGkmKkxqERuqESB33eR9z/k51OD4GZztI2F5FehIjxDu/mRMxylzRvgxascA\nsf34rgGLndIG8VhnODZGNZaFPvlpkZvgZSphKnG0pVqwrm+DiGl7xENPoQskT5vT8LwDuXXqoBTz\nlhdQPO6rsAXKmMqRmBc/EwCwTYqZ9xeY7CuJg2QroJhF9GT7AxKnQo42MQijeQN0Rt5xU4fteMg5\nebylPwa10idSYXnnfCL58Su348mtS/D0tucwHAghIg8j2DDD30f+/Cf3Bd3PWgO18DU2Fsfx8Pra\nnmFL1uRzOR4JlXYtzXrfXZ33cAlroXFuhYhEtqCG/b4RAReWkRvHlOJbv34Bn//RUwGyGS488Xah\nlogfcC/Du9Uudnd4aMlGAMCK9cFFUVAlL/yaO3uz+O7/vIyb734RQDjJ9/Tn8dWfPov/vPe1sm0H\nGxF5H2Lgcp8e8QJu85OSl+CyKR8RWeeEqEjFDWiqItxXhq4G4rfE8RPWvLxvkYsabznWy+uwY9vx\n5zUPBT4D4Mufqi4AAqjB2vKERevS4eo0EU3xyds0NCTMIHl35ujLl1reXiDmXXRcEVsnpCRhTvXd\nxUNuH+pitejpz+P7v30Zjh20qDXTptYwczcTT2rqwv51drbD3dXC+qYbgOohldBx2xePRW1lDBk7\nA0MxAaIGLG9DYeSt+AI1gfkC/GQ3x4DjemhPt8EhDr551xMAgK5cD4jiwcumEDM1DGR4jbwjXLfC\n8jZTUtzft7x5XH989RjUaC1iO8cga4X6qwfXw8vSmv2/L1+BL972NJ5bSePv/YX+wHGinaylwXE9\nPLVlCfr1jZClCbjl3ZcpAIoL64h/iYXGa10rUIqubA8Gi0P4zSNv4KofP4Wt3ZlhNbdLX7YD2SJW\n71oLrWETVOba91gI47W1PfjRH1/Dd38d7o4W51IdvLTzFQxk/UUst4qD3+1hlf009DYa/y+zvAsS\nebthbnP/M0X1hFUnt9flHqF9xe6M561d9HmRPQHDLYzeLrRS5aJDEKW6GYFF2zBW86ad9J70MC3/\nsORH3tt+1cbesm0HGxF5H2KQyduSyLs53YA5jTNgKiwm66mIWzoqkqZI7DE0FfUVPkm7riLI29ky\nDsUNkwFIwiEy2Uj//0bvWwDoAsHfTv/fHLNMxHJlyzwRM6DrKrW8VQdEcaEpPB6uIcXc5isHX8fD\n6/+B3iJzKRcSKNgutnXmoCoqNnX3omh7Qq5UfAdhMWtOiEYBLhxk+k1c87MleHNzHwaGXMhtTxW9\n6LvMAboA4W5zVodOHJal7hIYGvMU6P6POGNnYaksN0Amb7Yw8VAU4QNSQt6KXqTKdVDgekR0U4Pq\nghCC7Sx5jeTSKBRdZLJ+jTx/6QvL20j6fd+lmDe3vFNGCiopDy2IVquOKch7xY4NAIA7HlwJQgj6\nCiWlSxqXf6WWt0tCrErNRUXSxMBQEUosCzVGX3KqomJd/4aApekRD//x4m345Yrf4cnXaDLQ+m0D\nw8a8S8l73dYB/PjV22GOXClkfVHiiXpdcq/L4C9xo/0N3Lf+TxhM+fFM/rIujXFvJstgtKwXf3O8\n2rUcAwU/L6DUba5YGcSP/Jv/5ap/H4ekKgO5MmRf8HZ6cpcuSDj25DbfuGMA/3X/soDrfTjyPpRy\nuEuH+HZi3p0lLU7DKjhMQyv77FBBRN6HGHi3MZd4Abd5OkGJQybvmKkFpBh1XUVrbYX4m3gaeofY\nKp9ocDtHiAYn/Bz+viGPQpjbHIA1eSn9n5KYNz1GBzQXRHVE85WYqSNp+eP86/q/+brmtonnV+7E\nt3/zIlxbw2CBveTYGF23xDvAPuelal2d0o+LaDSBTKUdxVy1ECBcriInn58vWlzPg6XSfXVDJu8M\nYiodu6gVB+ApLC9AsUXSXqAcDzSWbjHCdlxPiKcomgvXI9gyRInMy6XYi1AR96efuXeHbE7eKSGB\nC9XzrT5meadNSu6EoKScboiGDYgKkqXPhpyYtq2/F45EzqpnsnI6D6ahwnY9IfIiQzc8NFbHA+1n\nTxt5MmY3TAfgl8ABtAFNxsliTd86keCn6wrk05JhyAYA1m4rr4tWNV8NcHcQOvkshGEnt4ltnHxl\nK01uHSvvs7p3LX6+7C68YD8ItXon9Kb1cFwPHgsDFG0XetPGkkF6tIwS/n0EEBDuKRsv8fZIzm8n\nbC1n4e+N5X3lfzyOl1d34bkV5Tr85eMoP9ef1/zfbtX+9hV7Gnep5e0GLO/wY3mf8mSMl5mW73co\nl5lF5H2IgVvepJS848wFzcmbqIgZQfI2NEVYhAAATxUPKIXix39RalmHrDBD3OoyFATd5lTpTYei\nuVAU1gwFQNzUUBEP9oC2XZ6lreI1Fssjju5b1oxcbVt+80iWM3P9J9W0fz2uCkUliM/5BxQzD6J4\nAVc37bxGaDy9pDOb6xKYjFw1g373uv4NsD0HMS3BxufPnY0C8k4BLmxhvYfNkaHQc7oegcoFDVUX\nRdtFX44nDkpa2GyBMsAWXZt6ekA8BU7RD4koiivkTLmLO2kkqTEqhwYA5NwMVI/OPfdCFIlv+e0c\npLHCZHYU8svno4KwzHvdEfK8A5LL18vTcyUTCpK8xpuPQU+IOZRlcLn17xFPJPFpqor/emCZf97d\nSHgOZotlOR+q7ore67sDf4nzygoS8/MBuFXtegSKlYU17Um83hOMsfMXf3+BHpdVemGNewVG+5so\n2g4efHo9rvrxU1i5cRfUip7AsUrA8vYTqoazvPNOAf/+3K349crf7/aa9qTbrVZ24U3mPQPefsxb\nnvdk3F+YD7eYKP246Nr4+6YncNeqe3Y7vr2F63n42h1LcO9j4W11AZQ6YgJW9HCaCNt76LuxtoL+\n/sLc6283Uc0jBKs29r6jxi97i4i8DzHwhDW3JObNLW9D5a5XD5apo7bSJ0VdV8vIuxQyAQXIhoQ8\nCnsgd/m7EjFK3rL1yWO0MVNDMhaU7LSJLb6DJxEpngHD8jCiIeWXAknhQSK3PWX/5vNAR2MaJ8xs\nDYxXYdnqPMnMMjV/PjSnrDOb43rCbc47hf34lTsAUPUzOugY7E1UCtT2CljeQ12w3mBN4FwyTGbN\nP7NsOx54YpMYe9HxROa9fJ9UaIDioj9TxKadg+jNDQKOibVbB0RCG1RPlCxx8t64Nc+6z0neBcVD\nkRRgEDZ+fq3En1SejY5CEqZdA5PF8hXNgc403Tlx0fmk280YEd4WbnnHtLjwLshKev1539LnBNfd\nnwsmj9VtwD82/QuATzAXnzyOnstxxaLW3jyOnch5W+QtkvGYkp6iEtF5T7a89Za1UONZPLL1kcDx\nolQupMd63inikefpPX3hra1Q48GMZyhyzFsi72Es78c2P4nOXDde3PkqVvS8OSxp7l6m1IM14SX8\nz9q7sXGAVkS83WzzjTv9+/R2Er5K8wGyUmfEMG/NviKTc9DVl8fGnYPIOTlsGiwvS1R3Y3kPN/4u\nFs/mW12XAEYe8SMfwWObnsTTr2/H7//xVuixpXjxjU6ahf9WePjmQCAi70MMgZh3wG1OicVQfOvN\nKnWba6qwfADfsm6o9gl+WPIOURIjw7jNOUbUVfqHqwomtlcFysc42cRMXciJcnDxEz6GuKVjYmsD\nHGLjcxeNRnUFHWdBTiJ2Zbc3V3BTkU4YaKyOB65HJOUxy7syYYpriM96TKjQnbuQkoHjEphsMdLX\n9Dh6cr2iZGt2zZHivM6OUXD7a2ATGy/upPrvXBa11G0OACZbbK3fPuhnzGsOirYrVMrkudVVHVA9\nPL9qJ2761QtQDBq394h0P1TXLxdi5L1lexH5okv34d4JVmGQGeT3UQFxNbjwJ7UvT4nZKRiIW5oY\nLzQbpk7H1S/FefkCSTdcn7wNej5L8cm74PrWZU/Od3trjLxL1dOMjlV4YM1DcDwXhAAT26tw1GRa\nG19winCJC7evDs72MaKigTeMCUPfUAGPv7JVkJDcjc4cuRJqRXfA8pZDQNLFis5hgYQzBqphH1zA\naH3t0gLDldzmEnlb4eQtJ/r9v9fuxLLulaH7DVceV3RtaHV+WODxzc8AKPdqEELws9d/jf9dG1yo\n9PTnWRc8EiDm4cgvb7t4+LmNWL2ZlmxmbT+GvCvfF3rMvoCX5xUdF//96p34/gs/KRM7KvXWBGr3\nh1ns8NACfw4c1xNVDH9a81f86onnsKnTv++7s8K3sERBfr/fDUTkfYhBVcOzzbmVoTGZUUV1ETM0\n1ErkPba1UsiE0pPQ40c2paXPJHeYbJmXan5Lx7O9yzaPb6kVC4yBrI3KlIXjjmgX23lTkpipiZcc\nh6uxlxnLJm+pTeDoFkqSd6+6V4xHdpvLMWthgbsa0gkDDdWJwPWUlsNVJM3A9ei1NN5cnaIuccfz\nhCeBqDbuXf0AvcaqMWhPjQheuEv3W9e3EaYSA8nR+f33y+eXzRG3vAH4Cxtmeedt+sKvTib8/TUD\niurhjU19gOJC0VwQx4TrefA8iJg4J29uUaowaRmT7DYXjVmkBZur0wx5Bp4QZ+cNxC1dkLeiOeLe\n9hcly7tAnzdV84TbnBOXqcZD3ea9WXo8IUy6Vy+KEkYK/9nryXK3ugKTkXPe4wsxQ1wDNCcQ81ZT\nvbjnzT8LBb8f3PMq7n70TSxdxWK3hpSAVbMT1sQXBQm6HvGrGmRIuQWDdjl5F11byMNyD4ilxv3K\nDtUT3gWZvLmGQSl6MoMwSAIN8ToAwxPgcKpyj21+MqDBz/NKZPL9xV9XYe2urVjWvRKPbHxMfH7/\nk2tx9zPPIT7zceitawLqdjIxquke4bnY3p3BH59Yi+/9lraslWV4d2a7Qse4L8ixDP9C0cP6Aerp\nKG3y44sJ2VjbtwHLi0/AGE1DII7r4anXt+FP/1obOIaHRGQJYPkdEZu6BErCf/YzuxGl6WFW/HCS\nvgcCEXkfYjihbSEA4JSO42GZ5daAcFUzy7u1jr4oJnVUY1RzhbAeAQh3LHe5A74rm26XasI7R8Dp\nbIPT2eZvl9zQJJeCN1SJRNHvuGVqJs4/kVoZU0ZS17HIqAZQLPjkHbf0gIAMjAIjW7pPU20Ccxpn\nYFrdJKzr34AhbTs7h3TxcsyaK615GtIJk1qBsopcSUb9rPH1qEun/HMxxboYUzVzXSL01wFfxtXQ\nDJhG8GfCSSTjZGEp/uKptT6FUpiaLITj66sXbQ95FhNorvGTDGO6IVnOvshMNu/QlzBbwMiWNyGA\n6rG+6l65d0K27ImniVp2wE+kKuboPbKYWA50h1U7EAzZQ1DsONA5GvYW1pVMc0SiD0+aMxCDxa5X\nJm/umvcGqJiOVrsNg3JrTql0atsQJVtVVYVlXfAYKTAvCl3EObAMSU9/1HI8ufVZPLLhn3A9V5RM\n9bA2nUqImI3sNlfCyrcUV7yMB/dgeXMPSEz1PUCK6olqD368O1gFNZ4RLm0ZWTuHQk7DB0efAcDv\nA5ArOPjt31eL/Rw3PKntjV1BFy8PXcge7KGcjd89/1TZsX99diNyBiVEo3VtoFc5J38lNgRr0guw\nJtM6/lK1uqyUUf+/ax8WvyEZL+54Bc9L/elLUXSL5fr/kuXN8asVv8OPX75dhMe4Vfzwhn/ihy//\nP2zxVkKv2waoDhzPw6/+7w08tGRjwAshpH+55e2RMg+kKpP3btrfdrPnbLgGPwcCEXkfYphcOwE/\nOf67mNVwBCyj/PYYnGBUDzFTQ1XKwq1XHI0vf5hm+fK4LQBBvgH3ouwelC1vosHeMBVeVs5WD24v\nrJyPKUnfhWxqBj588nh89zPzMGMctRZiElnlmfEbs6jbvLBsIYobJ4rtinR+njRycvvx9OsUjyXE\nlMfdrSnPCsubeNTytgzNT3aDH1fk1xC3NMwd7y88eDcxUzOggCa1aFKHXMI8Dbyvugw5NGCqscC2\ns0afiqq+2eLvQHvOgHyqi6JXBCEKWup8z0jcNH0vCCccx0Qmb9MsbE+lMe+Cr3QH1wAhCLjNdU3y\nTngl91+aJy7/6hQMxE0NMS3OzmvT5iu6DZe4UPKViO+a5hOo6kiWN53LJ1/y+8bLMW+e8ObsGAli\nGzDa30Rfwbcqq6sly3Dlb2BNfRquPgRNVaEqCoqk3PImqitCSfLcPrrxMVz1xNegN9FSL9cjwoPh\nZYOLq0DCWgi5Q/WEBSqTN/GYfKtr+6ED9jwljARkHf3BnA3Xc7FhYBNMLyUWMLe8eBs2D/ou7oJt\nU8liR0c2xyxCRn6PLN0kyc8SbM9tx5cfvwlPrH8hMNwEy81wulqhQRM6/4E4t+KhR6WJX2rp619a\n5PXYdBHVP1TwNQXYgpiXBZaSmWx5bx7ahqU7yrPOf7Xy9/jNyj+Iv4u2G9B8v/Wl/8Y1T92EXfle\n/HrFH7Az0yme9VIZ1NV9a6HX00UQJ+A1fesD+6iJwYDbnH+XJ2nYc0+G63riPnLIev6lynUyIvKO\nAADQVPYjUspdedzyVlRPuNJrKmJCwjBgeTOrNpDMEaKqxlGZNANxW0teCDA0VfrkbmoGFEVBY7Xv\n9pXJm7+EYqbGXKBKILNazlavTtPjWpK+znhpkhx3hauJId8t7lLL2zTUACnpsWLgHKauBRc2jNhM\nzYSmqXC8oOXNyVtX9fLYqpQ3YKlWYNOpI09ERW6cv13zr5d7PbTa7XijawOKjg14KtobZcvbpN4F\nkDLL2/OISNrjpMFnvQAAIABJREFULwlFt0EcAx4jb3gaFIVlC3P3eWkSIRfaAc1GB2huQNzS/fun\nOYiZmp87UIhTS5yoILaBIjJSzLsI4mpYsqwbdz9MXZM524/r8nixl6mE09kORSEYcCh5X/Ghqaiu\nCU6vmhhCwaRuV8NQy8ibXoODdFJ+PoOWqNEuNVlhiwsu7AJQFz63vF3J8k5qKcQd2kRGUT3YbJ4H\n7SHEtBgaN58HZ/toAIBDbL8Gmn1Hykj4vyvFQ6Ho4q3e9cg5eaSdVnj9dWIMXUO+KtiWHuZKdw0M\nZei4RJMdRhp6y1rE5vwdD+/6LYrI4c8rngxc85Cdpde1fip01RALKNntrbe+BcegnhAufyyseOn3\ns9T5E17dvAFf+q9nxCLHigWJqbQ3Oq9l93rpb3ht34bAdtlb4HgObNfG13/+HP7th/8Sn29l5ZM3\nPPtdvLDzZfxr6xLkmOVdCHNJMw8cH2NNrDqwWUkMBkrF1vZuxuceuwbLu94Qn8ltb6EHr4mrJAI0\ncY4QgnvefACvdvnhCdvx0McSE4frQX8gEJH3IYzm2gROmNWKL15whPgskE0eAqOEcDVVCcgbkuEs\nbwC1lbEAoafiQWICgNYaP0lNjudyWJLb3M821/06TEk0hYu4AEBVih4XsFRLM+Cl97OaYGVWsuWt\ny+QddJsbuio0vGWYqgFNU6h2t7SY4Mk3pmqUkbec9CeTsxibNN+xEMtbjWXxt/7f04Q1TxPudtNQ\nYeq+Z0V0RXNMZPIOtSCY5e1fqA04BmzH893mAGJxlLnNP3bqBEHufFvOy9JnytMRs3RhvampPjqn\njLy9giU8EKQYR8YblFzGRX9O2Hf15XwrLONkaEzZMcR+3EqLWRo8zd83zix/T6X3z9BUOGD3UnwH\n/d5kXAqTlHTVcwerpG1sIWeb8DIVYpz8he95HqAX4RXi+FjHlTAddqzqsg531PJOm0nYDoSHpugV\nxQKAex8qrKT4XXHPx7Iu6vKOF5vhDVXD3joGAPDcm742+NZeupghjo6BQWZpMsubX6Wa3hUoAyxk\ng7+/jJ1hc6RAgyFCF778bT/05vVAMQEvkxZiQaVqfRxLNvtlc+NHVOH8U/zcDzW1C/IP8ub/eQld\ng9TFbO8YgYQex/qSJjy25xPj7ct+gy8/eQN2ubQpztbuTKjVammmkKQNI0au9Oc4dCxyxjtA3d6y\nbsCDq/8BAPjjW38Wn/FjHdcLvEMAQElI5J230Vvow5Nbl+Dny+4Sn+8azIuZiCzvCACoxfzRRRNw\nxBh/tT4hdQTcwSoU3pwdekwpucdMLaiQJJM3CZKZkDdlkBOpONpqq6T9yxcSlaZvRfKXWMyULT//\n/LpE3tzy1lTNJ9mSxUVx3RHwWMIUb0nKY96moQlXOOCX9nDi0jU1IBwiX4OuKtjUOYTfPbpOfM7L\no/RQ8vYXKLy0SoYWIG95MVOSw6C6UIkuLNiKhCnlNDjCCiCOQd3mhJM3LwVzoai0tj5XcFDgbnMA\nlim7zTUoACxDCyTNAUCB5JDQaC5CwtIRN+j86rU70KmsgWKypKd8TMwDKcThEgdEp5nJil70xXDY\ngi3LuscRQtDv7PLbz7Lvz7t0e6ezCTuTVNr0kxMvw1ktF9Ahs/71pqHCUYKVA2JRKB5PAhhFjEx3\n4OjmuXRqpC588iKosHIezGKNyDsAaLKiYhQB2/QXSACgUMvbIx6G7AzSZgq27Sc2Op4jkTf9jsp4\nWlLCo9v6cixbv0jnmeTpwAekBc6OPhZGcA08vIS6yGWyo99B5X7zrx1L50hx8M+XttA+5ZkiuocG\nxBzpii5i5tzw1iq7oShAYeN4EMeEogDZYkHEkkvj/tttP8FLU5WA0Iw1+XmYE18AJ/A1W/rx6rrt\nbJ4NjKrsQHd+FwaKfqWCHMte2fMmrftnv+MbfrEUP/uLb81yFN2i0DRwPSJCFv7AWNWJ68sJK0RF\n7oVF9JqsbMBtvnEbnfNdhV6YE16A3voWc6F71I3O3iFpvQLENpnbnC0M8g5eWeu3C/WIhy2dQ3h+\nVac/3ihhLcJwiGtxFFfNg9dfH7q9lLwtUyuxvOWENXr7501uxPc/O5+2/pMs7/GtJf5MMMuCn1sr\nt7wbEv5Cg1tIPGv5e5+Zh7PmjQ0dK7e8AYiM5dIMYJJP4SOTzg1+oWR5u92tZePxX8R+9q0MQzWk\nemH/+3gs2ND0snri0fX+NRoh5C3PtxJI6C8JA6gudMUQZXQVSRNtKRqX16q6AuSbZZa3r89OAuSe\nLTjCbQ4Ahkl8kvc0aBpLAJOS5gACm+TAeSNmakjr/uIrhz5R1uQWLDEPXoH1ZlcGac9yzRPhEL5Y\n4q1fu3O7YKMgLF5ueXsKJYo3B/3yqE1bbdz10Dq2ncdXM7Br3gKgCNLjC4BYjL2U9SIUBUjoSVw0\n4TwoTgyKbqOphu4vXP/FGEBUmAodP7f+O6uepIsgT4Preb4YDot5d2W74REP9fE62K4nFp02sSWl\nO2Z5mwkpt4Fu4yEEl1Vf8DnK2QUUbRe/fGgVVm1hjXpcXWyX8wb4dRLHFGI7ikoT2VZv7sMdf10B\nWymI+VWhS25ztsBgdegkmxZz2DkwhKLtQW9eC60mqKrW7/o1y6qqBKRhAdAOaZK1nnPZ78s10Jyg\nZX49OT80kA35/cmu+tfWlau6FdyicJtD8VhIyQfPc+GLqEwxA89mioKuCkVzgqV10m9Qq+wRTXgc\nhzDLm97Hj7Z/Bl6mIuClsl0Pv3/CL9+7c/n/4DtLbsMDT/niMZHlHWFYqHu4Y2aJNWwZWlD3N1Aq\nRh/kptoE6qvi0FQ1QO6TO2rF/08fU4svXzjdj8ejNL5OURvzCZ94Kl08MJd5Q3UCR030CVa4iAGk\nErIrmrfwLL/YMXV+TJx386pImNA1Bc7WsdQqIeUxfuIRVFoVpaeDqRmi5EgJqXU3VCMgvfjZD07B\n5R+YIf7Ww8hb2t+TpEcntNUFd9RtqIqOuGR5z2+eCxAFesNmP6Pe1ZHJ2zR2yaw6a+ozfptXx0Au\nHyTvbGIjI2h6nzVNoeTLXtqJlEvdpooHz6afxS0dST2Jwlv0+lzFFpa3V4gL0RauVpbxBkUSk8hl\nYN+/dscu/Owvy/HEm5ScSYaFW3jZGrdwTN+789yrg4J8XOY2R7IXUF2MVuaCFBOB73hsgJYUcosx\nriawoycLt6hDt1xRiREgb0BkxOftIjziIR9jFmM+IfIKAFoOt2pjL/64lNbzt6aaqQyqwlu32uVu\n81hSkD/XyOdVBbativtJPy/gsZe34ull27F5Fy2Rq02m/aQ/j7vNFYgcCFsqeWT3dzBrY8122mKX\ne0Dyeep2J4SAe43VWAbEU0AKcXGN37l7KdZvH4AxgvUzcHTkXliEpNMIBzZ4GZ+mKgErmkP+zdhM\nuY84hig5zEreroxdImID3+1N57A8abDgFkTCGkqSyVJGUhCrIyzvrK+q6OqA7gT7juvhSWe268F1\nCfNuqMgXJE8Zu8ai7QbG8GrXcmjpXvEbEfu8S4jI+z2GcLe5VPIVEvMOWJYSudek/Bfr/KlNmDra\nJ3MAAUEYDpnc4WkBlzk9p2+5VyV88RiZ8MQChBGVSGarS6IuLnkDPA26ptDEKkUBoIAUEtAhuarZ\nS8ojBIs6TkBV/wwa72OQyZk37pARqJsHVZKrkhYBCb085i27zT2pfj5VojKnKABcFQ3VCUwbXYu5\nExtQHauC5VZBiQ8FMuppqZhfh6omhkQmLHENDOVtFGzfbd6XXI7p09l99VToqsK6zrH5GPMMVNZb\nmj8TBduFoijwhmhoxEYOipmHQhTAlmLezPLut/tgxIPkzc/Vn83h+VWd+PsKSt6lljePLTpMMCb/\n+kLs7CmIuHavthHfff5HUHU6B5br51qIlynJQzFzIt5tKQlk8g6Ia8BVCjBYtUYpefNcjbybD1iD\nzrYxrByPC9FQ8n19K81gbk01U8ubu80JJe+EpQuXdtKK+d4PI0jeXPdAdPDzCsiKen3672lzx4rf\nnS2XWqksROKYoGI7qng+sgUnQJwAkMl6ICBwPMePeceyIIUEAFXyDri49wnfclR0ByAqFM8MzLWm\nKqHlcmD3UYkPQq3sogtqT4NC6Dhkb1fGLre8lVjWJz+jnFjzbkGUivE5cnc14gOpy2CqlvjMcT04\nnoO8mxeqisTVoaiOmGN6fSXfwd3ujkcXAJoNuAZts+zySgJ/n7LjpTkChkmqO0CIyPsww5566IZa\n3oGYd3mdNycbz/MCCWvyQqBUfrB0eygUglhJrbpsrVt6+PG8QQh/iR49tQkfP20irr5oBkzNpCtu\nNv50wixrSiCTN38RVqUs6KqO6sKkQHtUUzVEOdCYmjZ888jrMK5q9LDXaGhqoJZ9XHMdPnDkCNx0\n2Vz/slUFhVVzoearsKDZF27hgh4yKhMJ6JqKL104HfOnUq+CiQR9YRh+0h0tFSOB8h7enAWOgf6h\nYHY9APSxzm10kaMGLG8A0OtYwhT7bOG0ZurZYQRAyTEPnSQAKNA1BQumNWFEFQ3Z7Mr3wUqweHKJ\n5a0YBVimImWrJ6Brip/YptlQFF+qlQghGVVoxW8Z2ibmwHN8qV23zw8ZKUZRWN4WErSch32HbrJY\nrsVkMJnHIGXRf9/Y0oVfPkL7NDudbSDFOAtNcPJmI0pQi7Ml2URj3sxtTsnbRdzSka70UBmjrV1L\nyZ+7r4u8RxD7DRb1XRhwugHFhcZEgyqsBGrScRAiuc0VOW4vJe0x0ti4Y1C4r0lJ7kHRsyl560W6\nwGChB+Fh01ykkiE04AYXWSqzvHUvjuKGSbC3jaLbmSWqN9MFDsnR3vSKS8chk3e2pH4bAPSGLYjN\noNnmhuWTYFKpggIFBafot2FlY/EKcShOnC7CJLc5j6kHLG/NoUTMUGrdK6oHKJS4HZewcj0ahvIX\nOPQ7MnlbkLfT3QJnRzubA7o9bmmR5R1heBT3QN56iaVoGcGENeLJ2xXpv1wmskSqkyGsLWCY5Q0A\nY6voD5vYFsa0BF3VMtEamoFvf+JI3HrF0SXnpeTI5V1jpoZjp7eIuHiVVcmuRRMNW2QYEnlfcfZ0\nfP7caRjTSo/RNZW9YPh1aSJO1VSTQGOqBhWmbJkH57M0/p00E/jwiePQ3ugfo6kKvMFaJDedgOqY\n/7nc7IGjpabc2rcU+oJVOem4GnIFF7miCy3mZ1VziyWQxyBZ+l051vCFuc0NPRgWEYlVnobPnD0F\nNRUxen+IBuJqKHg5KLotXsS6ruITZ0zGNefT+9WT3wU9zsnbEucCAK1yF5qmv+W77l0dLbVJsVDQ\n67Yj1rLJF3MJkZYFAM9gCnBF+twkYjq83iZUDbIKDL0o9NKTajWyeceP+zJLTjGZNeZpuOXf5iPN\nyHv11h68vtFPsgJo8ppX4vZWrBwsNYZ7/rYJBP4C1CU0YU2zCsg4Q+iobGGeJhWEKFA1Rt6eDVM1\nYLOsZrHAqejBC7gPWsNmaJXsGow4aiuo0EvOLhey4fFuIkkFr9vZi9gUKpzCFy5y3NzziEgMEwtX\nISTjoChJ2Y6qYhnlJeENVSXoLw4grqThdnb4izUtaBUX3jgyMA5ZMjUjZYIHcmMAQPEEeRc3TMIM\n71xYmomiWxAxb1GD7eooFF0YqinKHh2X+Cp2IrFRh6J5yBSkeWTkrQ42wB1gZWWqC9vxqKdDs0Fc\ng3lwuOVNv38wa4v5cLvaUB1jybvsGY9behTzjjA89mR5lwovWKaGie30ITthVmvoS5KngHgEQQlR\nibiUEPIu7fTEccX0T6Bi2wkgmUqcefTIYcdqqDra6lOoqYiVfS6j1Hr33dYkKNTBj5dUz2pSScwc\n71tqhqYG6n0BX7iBZ30nDD9cUGZ5l2Sex8JKxbgbnhDELR1nzO/AvMmNWDituWxfSyuPmXPyVrhl\nzd2sRRfElPpCM3KXQx2Vdf6LWGQrexp0VaVubzlhx2KuVtfXntfZfSaOgSFniMqzshc5d5vH9TgS\nehy7cr2iJI+/zD+4YIw4f6eyRri947qJz35oqug0BwBoXYmCW4ACJZDMJ5frOBq1evM5Rt5snKpD\nCXj2tCTMup3wCnGkvUYqYcleuq/iL4DqUPJm46urjPueE83P6OcvfM8jovWqxvu6aw5yWQVLWJtM\ngy1aHTh0MZ2gNdod6XamSgfAU6FoLJud2DA1U2RNBxfQ/iKNz21N2gI8Ddtz23HHsrsgMvqlccLT\nxMKoM+tnO8ulcIBP3rwlqli4Sm5z3oa3Wm3CN46/KuCh4ffC1bLwiIdxDS04YVYrxrXUse22P5eA\nOG57Fx1vgLyZZXzVjE/jqhmfDswBvRd8gRLDUMaFpZk0YU3EvNn8cfJWTJFQ5rgeduWpp8kX86H/\nDhX8MSi6DS+TRmbVLH8Bwo7f5qyHogDeUGXAbS5yC3K2mA/iGJg7voWek2kimHpkeUfYDfbU67fU\nhZyMGWitT+G2Lx6DxaeMLy9XAkRrPyrmIFnGEonK33v6yJNRH68NxH5lWJqJq886ATd8bI7I+JXB\nm6+UeglKt/NyH8sILjgqmeWt6DbSyZBac0XOXA9+h6YpActbBpf7TOp+LH5PlndY9yQeo+eqcecd\nNwafPnsKmmuTqDQqA/uGldvFWekWfzEeN82vr1WkF7/O483SgqxdnVZ2PuL6lrdcIy7I3/W158e3\nV2H+lEY0pqtEaRBPaNOlhUtNrBo9+V7U1zOyZy/CU+YEdeCJ4oB4KmaOa0RTTQLfuuyowPa8W4Cl\nmaiuKF8EAUBRoyV7OUbeosENK9dzk53wFAfurkbs7M0hm7eF29RGHlrtdroAkcSBYixPQdHcMnf0\n8nW70DfAXMUa70jmBBZIwuOkuFTVLk5Jo6OizV9oen5M2iU2DCk8M5yXAQDqE3WoqYgJ1/1rXcuR\nJf2wJlBJUd/y1oU7l3sv7O0j4Q0wi5aR8zPbn4dLiBAb4QtX/syYY19FhtDx1+ktSFspGLoKj7e5\nZZamrdHjm5J1+OiiCWirpgaBYmVhjnsZWsUudk56n555lWaqb+rpFdfG3ea5IR1JI/heUHQbnsEs\nZ9tAf6YIS7NQCIl5wzGQLzq+qJLqoug6uH/NX+k1MhU7fo0i1q7QOm6xuJcaBdmOh60OFW5xu1tD\nLe+hrB2o8EjH6Hvig8eOwHc/PQ+WoUUx7wjDY97kJswaX4+vLZ4Vup1nVNdZ9Zg2uhZnL6Qu7GSM\nJWaFkTezvUvbBcqiJvKmM0Yvwk3zrw0mp5WgtjKGUc3h5M7JebiYOd8u9MdLkt7SJn0BKbqDdLyc\nvGXLu1RIxtBUv+SoBJwYZMtbLyFXTmAfnXQhxlaNwsiKkqYlAM6Y34HT5rXjU2dNLtv2hWlXBWr0\nwzL2E5rk1lc0jGzyCb+m+xhfI54n+kj3tEFvxw+O/ffgCT1VinlL99RgbnfPz3jXVBWfOmsK6lL+\nvXOKLAFLWrjUxqphezZ67R7qvuS92y3//GkjhYq0CsXTce6xNI9A04KLy135PliaiY+dOhFnzO8o\nmwvCwgCdPVRHnN8jTmI7MszqtC1s685Qy1tWA+WhBdvCSbOobn/CYG1Nx7wu+otzwn9pdZcIJ2ga\nK8nTnMACSSgPcnI26QJjRLoVlsmS+jwNHmhfe9uz0dMnJToNoxxYvXURLM2kSoeyVCk2+vMhW95C\n598RcyD2Y8f/c9OT6NPXQIlTD0ap5a1oHvQxNJuee5Fk8halWCo9vi5OibE6Sc9jtKyHVs3ugfQc\ncm8N11bo7suhM0sJ/Sf3vFn221fTPSC1G0CKFrxsBdZvH0BPn4OcUxAiLVzbgdgx5G0XGgwxxgIZ\nQme2G9PrpooWvdzyzhTZ74Qt1OIaj/v7mgeO6yGDXpCiBZJL0wx1fs/ZImkoZ0uuewOVTGggFgcq\nUxZMQ0XRdvdoYO0vhJs+EQ5ZWKaGK88tt644UkYS35p/HSrMVHhMOqS1J3/Zlbb+k6340pZ77wS+\n5R1O/pogb/riLiXvuJThHeY214gpvqd0gUHJV8XE+Cx0NFQFtnELf3duc+5Wntc8B/Oa54SOP27p\nuOD4saHbUrFYwPIPu0dJPQk4/vbqtH+9llcJe/0UWJOfB1GZFSBZhZahIaZbSOoJP8bo6dBUBTFL\ng9vTgqJuw+zw5SHh6mVd32TLyCkyYpeItyZO44V9hX5U6JXg7RsURUFh9UxY41+B7dlIWAZqrKQI\njWglnouck0M6UYdpo2sxbXQtHlqyEYWVR6FuyhoMEhazJ4Bjq+hoTQjPByfvXqaRbqoxbO/JImbq\nouc44MtbLpzSjounUNnadExanNWzpL1Aq1z6Ha45CKj1rCpAmmON11mzeD57oSeNhL/wJCo8uEhY\nOlzVLbG2gwsYTt5xRp7phBH4neaJn+XtDXBi8suYeLlVIJ9B+v+COgjVGoKXjwsPguyB4z9z/rsy\nNBW9/QRWo+/9Kap0DPUsVl2TKM/VCHw/I/8iyeP1tT348f8+g9gR6+EO1ACuIQiZQ6vugqIAxS3j\nAU9HvujCLChQzaLoC6ym+kEIdWsXii7i4HF5FzZT4RvoKxeEGsxnAZjQ0rS6okKvQhcQ0DywHQ8O\n8ZUCs3lb/K54eCJrboNVxWrfPRWVcbqIzjssVBUbgDbiDeSKp5XNzYFAZHm/B1EXrxk2mexblx9V\n9hmnZbIbgi61yt8JRFx+mFOqSnncXkZCcmvL7U55rJlnm4dZtfzlf3TNSTh7zKmh35PYjdv8ncLU\ntYALN8z7kNb9a7JUU7jhAZoMN29C0NqX3eomW4BUxZi1ThSAKNBUBcmYgfqqONydI6ES2UrSAhYz\nAD+jHxAvYpng5Xr+ilgSs8bXi0XloglzUUmakXcLyDm5gJiPripCHpQjVhL394aqMdM8xf/A1QEo\naK5Jipp8txict7pUBTp7cxjMFuF2t+KktuMB+PKWDekKUXWRNMt/G7LkbYJtH4qvg9FOFzky2QkJ\nYJ6Mp9iIaVbwuXV05L0cYpVDVMSmpLJDBifvGHvuUnED8o8jD2r1FlbPFORbEaf7ajXbpQ575RoO\nAPWsKUYRhCWr3fjxuThnQfnikqvrmYYq7jl3mxdUujyrZ5Z3Q0U5eZeqNxJPRc7JY9naHiEA43bS\nZ3XZup7gHHDPgOwV83QoCi+1I1CT/SC5FFRCyZ9b3takpbAVapWv3iDVkrt8AcF6rTdSQZZRFvOI\nSZZ10XHhoCg8BnLuBPds8GeBjRhxg3fQo+SdSa6F0bwBm3qD7UoPFCLyfp9hREMKN867BjcvuEF8\nxt08Lvt3XP+5uOGoqwPH7U/y5pa1S8KTO9QSy7s05t2UpOpN9bHaQO05fzlzyzuMGPnLP6xjG/+e\nZIjlffHJ4zB+RFVACW5fQL9fETHN/kJ5b2ceFgAASw+St6oquPC4oDu+QrIkeQ9s/pKloQefMK5f\nPBvzpjQibUrk7Oplcyxn3E8f2YxPnDEJx83wBXZ4xj9A5+vKc6dhFksMvPCEsRjZQL8/5+QDSXma\npsLZOg7FdVP9awxJ2otp/vg4cTbX+Za36ypI6v51N1VVwiME67cPQFVUnNi+AIBfTiff0zE1I/zs\neAauugbQOefQG1g3L4mYDFWjiyJmeXuqLeLoHPa2MSDwUGxaRj/gmvNmubdJMQsgRKEd5cDIW/N/\nG0VlqGwMHfXU82GOXAU11ReYJ/n7AKDIiJfY1LXb0ZTG+NagZgPgaxYQ4ru9eYJWERkYqiEWdfXp\n8pBYaSIeHAMFL49ETIfCeoDzkM+ytUHyVpmSn+w14Za8Yuap7oHmwstUwjI1arm7vuVcTG1mx0jN\nhQIxawIt3Y8RqTbUWLWB8Sqai6ydp78VtmjpzxQDx1tSCaCzg4Z3+D1/bPNTeHzz037mPQn3KO5v\nROT9PkRDog6VVvnKmbvGLSWFpmRDYFtIXtY+QxXkHX5SlcfaWcy71FoZXdmBzx7xcXxp9hXB47ik\nNCPv0pp3gFs1VIq0FMJtHmJ5nzJnBK67ZFawZn4foCgKrr5oBmbWURWz0sQdAEiYMVHrbGkmkjFd\nkLKmBUkLAJKmFONn19DMFjgcfOlVmbLw6bOmoDImddjytLJEx3qplOeoiW1YMK1ZzB2AwPOTCLsG\naQ5ly5vfS/klHUbeimv6nhNO3jVJcbzreUhLY2itoZ6ATN5BIqajwkoHqiHkMVZYaehvLkL+9YXi\nM/k+xIzdh5scj8BQLKiJQSjxAXhKOXl7fY1I6Wm4Zl/g+NLKCQFXF2JKybgRVC5TWaxXImdu9QF+\nXH+4RLiiTscwsq4ON1xKQz1hiZIJk96zwWxRsjqZ2xw5VJgp8ZzIz5x8DYCfsEk8DXllANvct0SN\nNo9Db9hRrtYG0JJD/qyLOTviaRH+IPkkYqaGfNHBUM6fI0/3NQ9Kx2N0rGJzRFATr0KMe5mE5e0K\nsR5O/rmCFPPWHDom1QOxTdibJtFxSc/tfW89KMJYils+twcCEXm/j/GlC6ejrT6JY46gJQ+cvGWC\nGsvqoxtq4uUn2EfwOHRYpjbgW+Ya++2kQmq5p9VNLluAcLe5RuiPKszyPml2G7568UyMaPDJ6+On\nTUR7QwpjWqk1EbS8939ayOSRNbjsiPNxycTzsajjxLLtluRaNzUqQiMatygKNFULvDiC5E3nroy8\nSxwnSSNoeZeiPu6Tt0zEHBVSA5rQBUiAvP2xcs+HHDoI08h3XCI8LPwl3taQFDK6iZiBasn676j3\n3fiJmA5VUZHQ/WssXfAYugqST6Hw5mwU101FXaU/3ngIecvEWSi6mKgfDUVzoTdtgIci4iELkLRR\nCaIEO7uVhif880slmpoaIG9P9ZOkOJRA1QDvXS/FsaUua45OiW/e+A7RwS6szDNlMNlbqVaeyt8S\n2MghLXljShd79BqY5rzJa8jpta8k//TVANl5t3aXS6USxwCIhoZqdi8kmeOWDno9Zx45HjGTajP0\nZf2ySRKKQnq/AAAbb0lEQVQbCJxfHo+i29CbN7BrTIjx8Xtijl6OdblV9CBXFzkAckKbodN7Egif\nlNxzy6I/suaaYEXJgUJE3u9jTBtdi29/4ihhhXLXuKym9qULp+Nri2dhTMv+eyC55T0cefPt6YSO\nb19+JCrfpqv6iDHUHdZWzVyKIdZFzNQxqaM68PI5dnoLbrr8SBh6iOUdco79AV3VcXTLkaFjNAzV\n713NSJeTN19Y8aoCUzUQtyTVOmF5S33R6ZkCf6UMyfIOJW/frZowysm7UnqRlxJj6THyNXLrkkus\nAggo1vE6esfxfCEPRkS1FTGcMa8DC6Y14XPnTEW15Sccjm32a/m5cE9ausbSaxC1+P31cLvbxLMD\nhJO3vMAp2C7aY+MB0HI7opAyyxsAKqQmL9yKa2sIL1NEqbWmlv825PvUW/Sbhvglf/52N6RxkRyO\naU+34YNjTgsQZMqS480avEwaWsUuaPVbQBQvcDwAjKroQEqpFp2+xjbV4rR57SK0YW8eL/bV0n1C\nOnU48FBGQzUTKUr6IaW8QRu3jKqvQ8LSkc07yG0ZAQyy6xS96/05mDLC9x7yTPWkkfS9H9K+y7LP\n0jE4BkawBQ4kt7mha9QL4egY1ZzGly+cXrbodPUcTM1EOvHOQmtvFxF5RxDglrfspo5bOsa1VQ13\nyD5B3UPMW1jm8IZ/2YXg46dNxFXnHYGFU6hs4R7lW4eBoRnCZbuv53gnMHVNvMy5vCTPOOfZ2jUx\npg6lKIhJ8WruNi9VsCq1vNvSkmBMyAtVJtQwy1te1ISRe3wYy1tkrEtWolySyMvRbMcT18jjoYqi\nIBEz8IkzJqO5NonqmL+gbK3zn9EjxtJrr5LIPVGywCgNxUwf689XmEuYuJqwVfNFF0nTAnE1EVMP\nI++URN7cyjv32NE4//gxZfsSV99ziZFENos6ThL/z/UQZOudZCqRe/HkQGy/QiJfRVGwqOMEJGzf\nQ5OWyRsKiutoAqLesLnseAC4es7nME+7UCwARjdX4YLjx4pyPrenFflXj5PGT3uNDwfujeGWt73N\nn6ch0Bh52kxhVEsFXI+gp9dD88CxwXOwRe8lp4zHF844DpW5CQAANUkt85SRRJznHYQtJFzd98rx\nksGqbrg1a2jioWvgklMmYOro2jLvQ09uV6gH5kAhIu8IAsfPpAlJs8aHtxvdXxhTORJAmHVIMb1u\nCh1P24K9Om/M1DFjXB3SZhKWZgaSrvYWPEZ6INzme4IpWd5claqmgvc7py8MTmxFtwhLiqNazHug\nqzqumvFptPfRspVSWhhd6ddUf2NxeQWCjHgIecsItbyHiXnLXp20Qq1dXv8L+Ja37bpoSNDnkBBg\nzsRgDgYAVPMFDILiOVzJri3V4o+x1PKWyPu4GS1oqfWvIWUl8K3510IfkhY4ro6PnEItyUVzR9De\n6LYpuqrFJaW9tnrqrm9IBpvoANSDcvq8Dnz9yC/jxBHHBM4vo3YoqONAXE2QCQBMrh+Lr5TkfJSF\nPzxdlNQBCP09aNLzLSc+AsBXPngs4GqC+NIhxxMoYlw8h+Wy0yYKWWRSjAnLXHZpA0DlllNwztgz\n/HOxPIiKBPME9jYF8hIA6k2Z2O7f95baCmiuf2/5d+SLDlRFRYfH+ruzkreUZHmHClY5BkYIqWMF\nbj99RovJbfQjVy+rfvHnItwDc6AQ1XlHEDh9XgcWTmt+227qfcWF4z+ICdVjMatxeuj2SbXj8b2F\n3wyWK+0FNFXDNXM+H3AN7y0Sehx9hf6DYnkbuubXm7JabRHzZqRTKxGX/DIxpSz6CTVjYbkZAD1l\n7D0i5WeOj24O96x8ceZnsHFwS5m7tBRhZYmyNR6WkAYA7eZErCg8g+aUb/0J8nY8zG2cgXV9GxDL\njMTZx00qO77GCo77psvmYiBbFHM1qmoEsCV8DHweZ42vx8dOnRgcu6WjLl6LOSNH47lupn3u6Zg/\npQknzaZCL6+s7mJSpdQzInsqvrZ4Nrr6cuhVN4nvT5oWvvrJo0TYoCXVhIUtR+GxzU/R87s6IHHC\nCGUatrxYg/icf9APSohGVYKJi6RE2lh87vj3Jox8PdZ61elpQoKFX266bC5Wb+7D5JE1UFdUwovv\nYseXPwfHTm/GP5dyOWBK3o01CXz90jlYuWEX/vPe10DsGHXtl1xDtVGHk9tnYkzlKNz94t+xsZMu\ndqaMqsH0TbUY0ZjCX5esByGK8C6kzBTGj/AXXifPbsPmt6rQyYVYuDgMlzw2LZCiJRZZKTOJmDK8\n5U1cHUdPbcIf/klbpBbfnIPYnL/BNQbE9pgxvOt/uGf9QCCyvCMIKIpywIkboC7Go5pn79aqTUuZ\nrfuCpmQjUua+kT9A3c5pI1VWc/5uwNJV1roRqGSJYaUxb9EUAcGyt9KSLz6DpIS9Dc3AhOqxGF9V\n7sLlGFc9Bie3Hzfsdo6w8IdcBx6WkAYAU5Nz8NkjPo6zR/v19kdNpkQ+sb0auqrjkknn47w5c0Q+\nggzZbQ4A7Y1pTB3lx65HVUqysiXPkio66ZW7qrkVP76+zf/Q1f0sZQCmqQUy5mXhoLilo70xHcgb\nOG56G1rqgs9jXPYGuDpOnOV/XzpBVdZ4YlmYlRhIOvR8adLW+iTOOYY1B5LqpsPCHzWDM+H2NqCt\nOE/McXtjGiczmVut6Lv+wyz3uso4KhL02r2S52DyyBpceMJYv1lKyTV091PCHVXZjqnG8aLne1NN\nAl+4YDqV2iWqyI8wVB2WZiIVN/CJMybhyxdOR3tjGpMa/C6AHz91Mlrrk2KR1TdUCCRHUstbCx0P\nAMAxkIobuP6js3Hy7DYACohtwVN5A52g5X3JxPOFpxAID58cKESWd4QIIfjY5ItQcIvvaAGxr9B1\nFW5nG2yjgCvPOx8AUJP21a+AYDKWXH5klpK3SJ0t/56rZn66/MO9wHnjzsIDax7CxJpxZdsaE37o\nZTjytkwd0+qCNevnHDMacyc2BKoBhgOPaZdm1nPwpD7e31lGW30SmzuHUFc1/MtWvi7iagGXv6Vr\nAZd02Eu7OdmI5mQjtmd2oqmy3LshW84nzhyB8SP8fToa0wAUmEocBZIVVutpR7ULyeOEEYcChS7M\nJCI6YWYrTpzVhhNnt+Hz/1WEmuqDSozQZ/mCo+bi0aWNWPyhCaFzoBerwIVd08N4ssZVjcZLna/5\n1QESqtIWyFZekkUt81HNVP50guT+lhdn3PuSZImH3kAN1FgWtudn4C+QmvzMbZyJf215BgBNPj12\nuh8uaapJYHl3HGqKJr8ljSRivN+BY+L85o/h3rcegJryLWuAVtmMba3EkZMace+W5diapS4U4hgB\nsaKjW47ElNqJeK2b9q1/N2PeEXlHiBACUzOHVak70DCYhKuzdRzqE9R6a29M4ZxjRokOaTweXB+v\nDVjbZoj4DDCsmN07wokjjgnGbSXIRCFaNZYgbKyqqgTaq+4Opmbg5gXfGHZxAAC5l06iL9sPBD//\n6AcmoK0hxayr0rHTfyulxjtfuXB2YB/TUAPkHQ/pLqcoCq6Z83m81rUC0+unlm3XVA03zbsWf1n3\nMOY1B88/aSQlNiPXgEJsA5QEJRdNU8X9VhUVcT2GrJMrkTv1O7DpJIbCiqNRlQ4nlTEtlbjinOHl\nluOZDgwZW2FU9pZpP3AsnnQBptVNxuyQMFgypkvtR6llfvyMFpxz7KhABYvrlmfX88WS21/ni+WE\nYGTFCDQmGtCSKs+h+dAxo5B5eRJeGKKqZykjAVPKjxhd3Q4vlxbkXZr1P7atEvW91YK8DcUs03pI\nmynoqg7HcyLLO0KE9zOSMQOXfmBCwPpUFAVnLRgl/q6NV+PauVehxqoWtdMAy1QPwbvUKyGARR0n\n4G8bH0d7upwggXIvwb6gcpjOdhy3f+kUhDlPYqaO044qb4RSio9Nvggv7HwFExpaA5/HTC0QTx7u\npW1qJuY2zRz2/PWJWnxy6uKyz6tSFlrrkujcUgN97Aa/R3XJjeTiIh111Th/8Ww8+vwmHD2Fkpii\nKHBcD4CCUU27n6fhoCkGiqtnY+KY6mFzH3Z3jRPaq1D7Vgp96BYKZJapBcIbANA7SGPSFSG9Cnij\nkXFVo8u2AfQ6bzjqK6GeBUPXcMmcU/DCE4/T79aswH6WqQXaKB83pTyMJHdPTJrhXRJrrCp05roD\nuQ8HGhF5R4hwCIJn/u8OYaSol3Tt8t9T7z57nz36VBzVNCvUnQr4mfEHEqX913eHay6eiQeeWheY\n+yObZuHIpvIOfg3VCZw4bQyeHqB61/EDYHGdd/wY/PTPWRQ3TII3SC3x0kXYpJrxWLVrNU4ffRLG\n1ldibFu4FT1hxL6Ve/L5U/YxPUpTVZw//Rj8YvkGuF10XsMWmJUpujiZPKqmbBscE1+c/CWMqKsu\n38awu/CWoer4ztHXI+fky/bTVQWktxVesg+jnWNw6YfKEyPlBWKlVU7eAK3+6Mx1v6sJaxF5R4jw\nHkLpy4n/fRAMbyiKMixxA8O7+A8WJnZU42sds/e8I8MHpx8JrO2GSzyMZuWP+xMzxtZhbGsVVm30\nPQSliYeXT/kICm4xkMAYhpHDtOfdEy47bRLufvRNXHRSeV7D28XMhmn45lFfxdeefx1AeF+Bs44e\nicqkhWOOaA58fvnpk/Dy6i6MaWh6R9LE1bEqhFG/rqtQMrUoLF+A6qnhYYFKSU2wKhHufeClm2Hh\nkwOFiLwjRHg/4GCw9x5Qmhl/uCFhxHHxxPMO6Hc01iSwamOv+LvU8k4YiVBteY5vXDoHb27uxbi2\nfVNIbKlL4tpLyj0Pe4vGZD14NnxYuMTQNZEhLmPhEc1YWELo+xO6poqSjPgwuvNT6yahVZuADVuK\n+MAJ4eWtnLwjt3mECBH2Cj/43AK4XnnSz26SzQ869kfM+72OxupgedecCeHW4XAY3VKB0S37ZnUf\nKBxKizZFAbhBP5znPWkkcN2xl9Me4lY4ZY6pot6RltSBW2iUIiLvCBHeA6geJpt4yqgavPRmF2aN\nrwvdfjBxqLnND0U0VvtW9e1XH79XMfxDFeYwCmXvJj555iS8sbGPlX3tObSkKsqwxA0A46vH4gfH\n/ntkeUeIEGH/4NjpLRjVVPG26qbfbZjvASI60JgyqhqTOqpx9NSm9wRxA76lezBx9NRmHD2VWsn7\nK6fz3SRuICLvCBHe01AVBR1N+67xfiDwqbMmY/32gVDVtAhBGLqGr148fKnZ4YTT53Xg2eXbUfUu\nqDjuDfzQ0qEYXBoeB3Qpd/PNN+PDH/4wLrroIrz++uuBbc8++yzOP/98fPjDH8Z///d/H8hhRIgQ\n4RDC/ClN+MjJ4/e8Y4T3FM4/fgx+eOXCQBOZQwEXn0wz6WVltsMBB8zyfv7557Fx40bcc889WLt2\nLa6//nrcc889Yvt3vvMd3HnnnWhsbMTixYvxgQ98AGPHjj1Qw4kQIUKECBHKILvQDyccsCXQkiVL\ncPLJJwMAxowZg/7+fgwNDQEANm/ejMrKSjQ3N0NVVRx33HFYsmTJgRpKhAgRIkSI8J7CAbO8u7u7\nMWWK322lpqYGXV1dSKVS6OrqQk1NTWDb5s2bd3u+6uoE9P0cI6uvP7RigYcronl854jm8J0jmsP9\ng2ge3znejTl81xLWSjV59xa9vdn9NBKK+vo0uroG9+s534+I5vGdI5rDd45oDvcPonl859jfczjc\nQuCAuc0bGhrQ3d0t/u7s7ER9fX3otp07d6KhYe/EByJEiBAhQoT3Kw4YeS9YsACPPvooAGDFihVo\naGhAKkVrTdva2jA0NIQtW7bAcRw8/vjjWLBgwYEaSoQIESJEiPCewgFzm8+aNQtTpkzBRRddBEVR\ncOONN+L+++9HOp3GKaecgptuuglf+cpXAACnn346Ro0atYczRogQIUKECBEAQCHvNBj9LmF/x2Gi\n2M7+QTSP7xzRHL5zRHO4fxDN4zvHYR/zjhAhQoQIESIcGETkHSFChAgRIhxmiMg7QoQIESJEOMwQ\nkXeECBEiRIhwmCEi7wgRIkSIEOEww2GTbR4hQoQIESJEoIgs7wgRIkSIEOEwQ0TeESJEiBAhwmGG\niLwjRIgQIUKEwwwReUeIECFChAiHGSLyjhAhQoQIEQ4zROQdIUKECBEiHGY4YF3FDmXcfPPNeO21\n16AoCq6//nocccQRB3tIhzRWr16NK664Ah//+MexePFibN++Hddccw1c10V9fT3+4z/+A6Zp4sEH\nH8RvfvMbqKqKCy+8EBdccMHBHvohg1tuuQUvvfQSHMfBZz7zGUybNi2aw71ALpfDddddh56eHhQK\nBVxxxRWYOHFiNIf7iHw+jzPPPBNXXHEF5s+fH83jXmDp0qX4whe+gHHjxgEAxo8fj09+8pPv/hyS\n9xmWLl1KPv3pTxNCCFmzZg258MILD/KIDm1kMhmyePFi8o1vfIPcfffdhBBCrrvuOvJ///d/hBBC\nfvCDH5Df/va3JJPJkEWLFpGBgQGSy+XIGWecQXp7ew/m0A8ZLFmyhHzyk58khBCya9cuctxxx0Vz\nuJd46KGHyB133EEIIWTLli1k0aJF0Ry+A/zwhz8k5557LvnTn/4UzeNe4rnnniOf//znA58djDl8\n37nNlyxZgpNPPhkAMGbMGPT392NoaOggj+rQhWma+PnPf46Ghgbx2dKlS3HSSScBAE444QQsWbIE\nr732GqZNm4Z0Oo1YLIZZs2bh5ZdfPljDPqQwd+5c/PjHPwYAVFRUIJfLRXO4lzj99NPxqU99CgCw\nfft2NDY2RnO4j1i7di3WrFmD448/HkD0e94fOBhz+L4j7+7ublRXV4u/a2pq0NXVdRBHdGhD13XE\nYrHAZ7lcDqZpAgBqa2vR1dWF7u5u1NTUiH2iefWhaRoSiQQA4L777sOxxx4bzeE+4qKLLsLVV1+N\n66+/PprDfcT3v/99XHfddeLvaB73HmvWrMFnP/tZXHzxxXjmmWcOyhy+L2PeMkikDvuOMNz8RfNa\njn/84x+477778Mtf/hKLFi0Sn0dz+Pbxhz/8AatWrcJXv/rVwPxEc/j28Oc//xkzZszAiBEjQrdH\n87hnjBw5EldeeSVOO+00bN68GZdeeilc1xXb3605fN+Rd0NDA7q7u8XfnZ2dqK+vP4gjOvyQSCSQ\nz+cRi8Wwc+dONDQ0hM7rjBkzDuIoDy089dRT+NnPfoZf/OIXSKfT0RzuJZYvX47a2lo0Nzdj0qRJ\ncF0XyWQymsO9xBNPPIHNmzfjiSeewI4dO2CaZvQs7iUaGxtx+umnAwDa29tRV1eHZcuWvetz+L5z\nmy9YsACPPvooAGDFihVoaGhAKpU6yKM6vHD00UeLOfzb3/6GY445BtOnT8eyZcswMDCATCaDl19+\nGXPmzDnIIz00MDg4iFtuuQW33347qqqqAERzuLd48cUX8ctf/hIADX1ls9loDvcBP/rRj/CnP/0J\n9957Ly644AJcccUV0TzuJR588EHceeedAICuri709PTg3HPPfdfn8H3ZVezWW2/Fiy++CEVRcOON\nN2LixIkHe0iHLJYvX47vf//72Lp1K3RdR2NjI2699VZcd911KBQKaGlpwXe/+10YhoFHHnkEd955\nJxRFweLFi3H22Wcf7OEfErjnnntw2223YdSoUeKz733ve/jGN74RzeHbRD6fx9e//nVs374d+Xwe\nV155JaZOnYprr702msN9xG233YbW1lYsXLgwmse9wNDQEK6++moMDAzAtm1ceeWVmDRp0rs+h+9L\n8o4QIUKECBEOZ7zv3OYRIkSIECHC4Y6IvCNEiBAhQoTDDBF5R4gQIUKECIcZIvKOECFChAgRDjNE\n5B0hQoQIESIcZnjfibREiHC44ZZbbsGyZctQKBSwcuVKzJw5EwBw3nnn4UMf+tDbOscdd9yB8ePH\nCz3rMHz0ox/Fr3/9a2iatj+GHcDOnTuxbt06zJ8/f7+fO0KE9yOiUrEIEQ4TbNmyBR/5yEfw5JNP\nHuyh7DUefPBBrF27Fl/60pcO9lAiRHhPILK8I0Q4jHHbbbdhy5Yt2LZtG6699lrk83nceuutME0T\n+XweN954I6ZMmYLrrrsOs2fPxvz58/Fv//ZvWLhwIV5//XVkMhncfvvtaGxsxIQJE7BixQr89Kc/\nRV9fH3bs2IGNGzfiqKOOwg033IBCoYBrr70WW7duRVNTEzRNw4IFCwI9ijOZDL7yla9gYGAAjuPg\nhBNOwJlnnokf/ehHIISgqqoKl1xyCb797W9j48aNyGQyOPPMM3H55Zfj/vvvx9///ncoioKdO3di\n9OjRuPnmm2EYxkGc4QgRDk1EMe8IEQ5zbNmyBXfddRemTp2Kvr4+3HTTTbjrrrtw6aWX4vbbby/b\nf+3atTj33HPx29/+FpMmTcLDDz9cts/KlSvxk5/8BPfddx/uv/9+9Pf348EHH4TjOPjjH/+Ib37z\nm3jmmWfKjnv22WfhOA5+97vf4Q9/+AMSiQRaW1txzjnn4Oyzz8Zll12Gu+66Cw0NDbj77rvxxz/+\nEQ899BDeeOMNAMCyZctw66234r777sO2bdsOSy9DhAjvBiLLO0KEwxzTp0+HoigAgLq6Otxyyy0o\nFAoYHBxEZWVl2f7V1dUYN24cAKClpQV9fX1l+8yePRuapkHTNFRXV6O/vx+rVq3CkUceCQCor6/H\n7Nmzy46bNWsWfvKTn+ALX/gCjjvuOFxwwQVQ1aCNsHTpUuzYsQMvvPACAKBYLGLTpk3ieN4+debM\nmVi7dq3okxwhQgQfEXlHiHCYQ3YrX3PNNfjWt76F+fPn4/HHHxfNPGSUJqSFpb2E7eN5XoCIS0kZ\noL2M//KXv+CVV17BP//5T5x33nl44IEHAvuYponPfe5zOPXUUwOf33///fA8b7fjihAhAkXkNo8Q\n4T2E7u5ujBs3Dq7r4pFHHkGxWNxv5x49ejReeeUVAEBPTw9eeun/t3eHOAoDYRTHHyGYJlwAMAjg\nAFROSC0STCWCIJCYBhwOwxEqegIkuqLBbRN0LQaBxkBZsdkaDJutmeb/05PJ517eZCbz9bYmSRLF\ncazhcKggCOQ4jm63m2q1mh6Ph6SfVv97VJ/nuXa7XdH+z+ez7ve7Xq+X0jTVYDAobX6gSmjeQIUs\nFgvNZjO1Wi3N53MFQaAoikrZezqdKo5j+b6vTqcj13XfGnq329V6vVYYhqrX6zLGqN1uy3VdrVYr\nNRoNLZdLZVkm3/f1fD7leV7xVWq/39dms9HlclGv15MxppTZgarhqRiAj1yvV6VpqvF4rDzPNZlM\ntN1ui3fn/3U4HHQ6nbTf70vZD6gymjeAjzSbTR2Px+J/4tFoVFpwA/gbmjcAAJbhwhoAAJYhvAEA\nsAzhDQCAZQhvAAAsQ3gDAGAZwhsAAMt8AxJ5C+54P8QOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzsvXe8XVWZ///e5dTba3pCQiAJCSWE\nIJGmoSSgjsg4gmCb4Tf+dCwURUdEQXGs41gYFQvDiIyIiKIIJIAgEBJCgJBKertpt59z76m7fv9Y\nu55zboiQBCL783rllXt2WXvttfden6et55Fs27aJECFChAgRIhw1kF/vDkSIECFChAgR/jZE5B0h\nQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8IESJEiBDhKENE3hEiRIgQIcJRhoi8I7yp\nMW3aND796U9Xbf/iF7/ItGnTQsfdcMMNoWOWL1/OBz/4QQB2797NCSec4O3btWsXH/vYx1iwYAEL\nFizgkksu4bHHHgPgpptuYuHChSxcuJCZM2fy9re/3fudy+VC19A0jfvvv/9vvq/Vq1dz1VVXHdSx\nDzzwAF/72tde9bVcvNbz3wi46667+P73v/96dyNChFeE+np3IEKE1xsbN24kl8tRX18PCBJas2ZN\n1XErVqxg/fr1IZIeCZ/97Gd597vfzW233QbAqlWr+PCHP8zDDz/MV77yFe+4+fPn8+1vf5vTTjut\nZjvr16/n/vvv55JLLvmb7umkk07i9ttvP6hjly5dyvnnn/+qr+XitZ7/RsAHPvCB17sLESIcFCLN\nO8KbHm95y1t49NFHvd9LlizhxBNPrDruuuuu4+tf//pBtblp0yZOPvlk7/fJJ5/M4sWLGT169EH3\nq6+vj09+8pO89NJLXHHFFYCwAPz0pz9lwYIFmKbJypUrufTSS1m4cCEXX3wxS5cuBYRV4IILLgDg\n1ltv5atf/Sqf+MQnOO+883jve99LT0+Pd53ly5czffr0qmu98MIL/OM//iMXXHAB73vf++jq6gKg\nu7ubD3/4w1x88cWcf/75fO9736vZ18p7ueqqq1i4cCHz58/njjvu8PatXbuWSy+9lAULFvCBD3zA\nu85I26dNm8b+/fu9893fy5cv5/LLL+fqq6/mM5/5DAD33nsvF110ERdeeCFXXnkle/bsAcC2bb7x\njW8wf/58FixYwC9+8QtvrL74xS8CsH///pD15MknnwTAMAy++MUvsmDBAi644AI++clPVllMIkQ4\n3IjIO8KbHhdddBF//vOfvd8PPvggCxcurHmcbdssWrToFds855xz+PSnP82dd97J1q1bARg1ahSS\nJB10v9rb27nuuus45ZRT+PWvf+1tt22bxYsXoygKX/7yl7nqqqtYtGgRH/3oR7nppptqtrVo0SJu\nuOEGHnvsMdra2rjvvvsA2Lp1Kx0dHYwbNy50rVwux8c//nGuu+46Hn30UT70oQ9x9dVXA/C///u/\nzJ07l4ceeogHHniArq4uLMuq2VcXP/nJTxg/fjyLFi3il7/8Jd/97nfZt28fIISiq6++msWLF3P+\n+edzyy23HHD7gbB+/Xouv/xyvvvd79Lf389Xv/pV7rjjDh555BEmTpzIj3/8YwD+9Kc/sXr1ahYv\nXsx9993HXXfdxerVq0Ntff7zn2f69OksXryYn/3sZ3zuc59jcHCQJUuWsHv3bhYtWsQjjzzC1KlT\nWbly5Sv2LUKEQ4mIvCO86XH66aezefNm+vv7KRaLrFy5knnz5tU89oYbbuA///M/KZfLB2zzO9/5\nDldeeSUPPPAA73znO5k/fz533333Ienv2972Nu/v+++/n4suugiAOXPmeNppJU477TTGjRuHJEnM\nmDHDI85ly5bVvNcXXniBUaNGceaZZwLwzne+k127drF3717a2tpYsmQJzz//PPF4nP/6r/+is7Pz\ngH2+8cYb+dKXvgTAhAkT6OjoYPfu3Wzfvp3BwUHOPfdcQJitb7311hG3vxKSyaR3P21tbbzwwgue\nteO0007zxuepp55iwYIFxGIx6uvreeihh0LWlkKhwPLly/nIRz4CwKRJk5gzZw5PPvkkra2tbN26\nlUcffZRiscg111zD2Wef/Yp9ixDhUCLyeUd400NRFC688EIefvhhWltbOeuss1DV2p/GzJkzmTt3\nLnfccQezZ88esc1EIsFVV13FVVddxdDQEIsWLeLrX/8648ePf80TfXNzs/f3Aw88wJ133kk+n8ey\nLEYqVdDQ0OD9rSgKpmkC8Mwzz3gEFcTQ0BBdXV0hC0Q8HmdgYICPfOQjWJbFV77yFXp6erjyyiv5\n1Kc+dcA+r1mzxtO2ZVmmt7cXy7IYHBwM9U1VVVRVHXH7K6Gpqcn72zRNfvjDH/L4449jmib5fJ7J\nkycDMDg4SGNjo3dsOp0OtTM8PIxt21x++eXetkKhwBlnnMFJJ53EjTfeyK9+9Ss+//nPM3/+fG66\n6aZQexEiHG5E5B0hAnDxxRfzve99j5aWlpo+2yCuvfZaLr30UsaPH19z/8DAAC+//LKntTY2NvK+\n972Pp59+mk2bNh0yLa27u5sbb7yRe++9lxkzZrBjxw4WLFhw0OcbhsGaNWtqCiGdnZ1MmTKF3//+\n9zXP/ehHP8pHP/pRtm/fzr/+678yZ86cA17r+uuv58Mf/jDvf//7kSTJG4OWlhYymQyWZSHLMrqu\n093dPeL28ePHI8uyJ3xks9kRr/nQQw/x+OOPc9ddd9Ha2spvf/tbHnjgAe+6g4OD3rF9fX0kk0nv\nd1tbG4qicN9991FXV1fVtrs6IJPJcMMNN3D77bdz7bXXHnAMIkQ4lIjM5hEiALNnz6anp4fNmzdz\n+umnH/DYzs5OrrzyyhHNuKVSiU9/+tM8/fTT3radO3eyatWqEaPKR4KqquRyuZoa9cDAAOl0milT\npmAYBvfccw8A+Xz+oNpevXo106ZNIx6PV13r5JNPpre3l1WrVgHQ1dXF9ddfj23bfPnLX+aZZ54B\nYOLEibS3tyNJ0gH72t/fz6xZs5AkiT/84Q8Ui0UKhQLHHHMMo0eP5pFHHgHgd7/7HV/+8pdH3A7Q\n0dHBhg0bALjvvvuQ5drTWH9/P+PGjaO1tZXBwUEefvhhb2zmz5/Pgw8+iKZpFAoFrrjiCjZt2hQa\n93PPPZff/OY3ABSLRb7whS+wb98+7rvvPn70ox8BwgoyZcqUgxrvCBEOJSLyjhABkCSJCy64gLe+\n9a0jkkEQ//Iv/4Ku6zX3jR07lp/85CdeVPiFF17Itddeyxe+8IVQBPrBYM6cOfT09HD22Wd72qaL\n6dOnc84557BgwQIuu+wy5s+fzymnnOKtPX8lLF26NOTvDl4rFovxwx/+kFtuuYWLLrqIT3ziEyxc\nuBBJkrj88sv53ve+50W4z549m3nz5h2wr1dffTWf+MQneNe73kWhUOCyyy7jS1/6El1dXfzgBz/g\ntttu48ILL+TPf/4zN998M5Ik1dwOwvJx88038+53v5tUKuUt8avEO9/5TjKZDBdccAGf+cxnuOaa\na9i/fz/f/OY3ufjiiznrrLO48MILec973sN73/teTj311ND5N998MytWrGDhwoW85z3vYcKECYwZ\nM4bzzjuPdevWceGFF3LRRRexZcsW/vmf//mgxjxChEMFKarnHSFChAgRIhxdiDTvCBEiRIgQ4ShD\nRN4RIkSIECHCUYaIvCNEiBAhQoSjDBF5R4gQIUKECEcZIvKOECFChAgRjjIcNUlaenuHD2l7LS1p\nBgcLh7TNNyOicXztiMbwtSMaw0ODaBxfOw71GHZ0NNTc/qbVvFVVeb278HeBaBxfO6IxfO2IxvDQ\nIBrH144jNYZvWvKOECFChAgRjlZE5B0hQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8I\nESJEiBDhKENE3hEiRIgQIcJRhoi8I0SIECFChKMMEXlHiBAhQoQIRxkOK3lv2rSJ888/n7vuuqtq\n39KlS3nve9/LZZddxo9+9KPD2Y0IESJEiBDh7wqHjbwLhQK33HIL8+bNq7n/a1/7Grfeeit33303\nzzzzDFu2bDlcXYkQIUKECBH+rnDYyDsej/Pzn/+czs7Oqn1dXV00NTUxZswYZFnm3HPPZdmyZYer\nKxEivGmhGxZL1+6jWDZe76542NuXZ822/te7G0cNXtjYy879wyxduw/Lsl/v7rxq9GWKrN8x8Hp3\nA4D9AwVWbekDoKyZPPdyN7Y98tjmSzovbOw54DFHGoetMImqqqhq7eZ7e3tpbW31fre2ttLV1XXA\n9lpa0oc8Z+xICd8j/G2IxvG143CN4d2PbOTXizdw3twc11x+6mG5xt+Kf/nm4wDc/+13oSiHTn/4\ne3wP9/Tm+NEf1ni/48k4F8075rBe83CNo/vcf3XzQpobEoflGn9rX+79+jv4+d0vsmzNPmRV4aK3\nTq55/I9/8SzPv9zNdVecytvnTHjF9o/Eu3jUVBU71JVuOjoaDnmlsjcjonF87TicY7hhu9BwN+wY\neMM9p737syTjh2YK+nt9D7dWaKobt/dz2tS2w3a9IzGOXXsz6K3pw3qNg0V3zzArN/YAsGnnAKcd\n117zuA3Oc3hh/X5mTWw+YJuHegzfUFXFOjs76evr8353d3fXNK9HiBDhtcE180lIr3NPqqEZ1uvd\nhTc8SroZ+m2aR/+YvZFcOJZtY5jiG1EPYAVqrheWgsHh8hHp18HgdSHv8ePHk8vl2L17N4Zh8MQT\nT3DmmWe+Hl2JEOHvGq6LTnrjcTdGRN6viHIFeRtHsc/bRb6kv95d8GBaticQqcrIH0mLY+bP5N44\n5H3YzOZr167lW9/6Fnv27EFVVRYvXsz8+fMZP348F1xwATfffDOf+cxnALj44ouZPLm2ryFChAiv\nHW9E8tYj8n5FaHp4jEzz6CfvQun11byDQWeWZeP+UuSRddn6VAyAzAE072x5iKZE4yHp48HgsJH3\nrFmz+NWvfjXi/rlz53LPPfccrstHiPCGwf6BAo3pOOmk+Nx6MkXSCdWbEGqhe6BAQzpGOukf0z1Y\noLk+QSJWHbiZzZUxLZvWxmRou+Wazd+A7H0kzOYDQyUUWaKp/rUHSFm2TVd3jgmj6pEliZ7BAk11\nCRLx8PMoayZ9QyXGtde9pusVSjq7e3OhbYPDJbJ5jaa6uLetN1MkGVdoSMcrm6BYNtiyJ8u49rqq\ndwOEANWXLTKmrbqvA0Ml4jGF/mzJu+dK2LZNV0+Ose11ntnZtm329OUZ116H5IxTXeBdz5cM9vbl\n6WxJeedYts2W3VniMZljRjfSkynSmI6FYiJ2dQ8zpq2OmFqbZGudUwslzbdmmJZV9bdl2WzenSGV\nUEknVOJxxfuOhgo6fZkidakYqYR/naV7V/B/G+7lIye8n4s7zjng9Q8VjpqAtQgRjkaUNZMbfvYs\njXVxvv+pswD499uWIQG3//v8mufohsXNd6xg9nHtfPQfZgLQny1x48+X8455k7jk7ClV51z7388A\n8D8VbbpKhvw6cLdpmWzL7mBq85SawsOR0Lw/++OlQPW4vBosfm4X9z6xlcvPO45Tj2vn33/6LBOm\nZ7A7N3LdnH+jOdEEwDf/70V2dg/z7Y/No7059aqvd/MdK+jLlkLbNuzKcO2tS0L3c8Mf7sYup/nF\nxy6vauOexzezZNdKmuoVvnvlZVX7n1i5h3v+spmvXHU64zvqve2mZXljB3DlBcdz3pzxVeev2TbA\n9+9dxZknjuaqd5wAwOMv7uH/Ht3E+88/jtOmdfLvP32W9iZfcFi5uZdfLd7I208dxwcvnAbAqi19\n3HqfiKq//v2z+c7dK5k+sZnPXSFWSGzcNci3fr2SudM7+fgls6r6kc2V+ffbljF1XBM3fHBOjdH0\nEdT8g0vvXFJ/YVMvP7l/rbddSuZomLUSuXkaVqaTz922jHHtddzy/73FO+avu5cAsKJ7JRefeGTI\nO0qPGiHCYYRmiAlhKK8BYDj+tQMZP4tlg7Juhvxre/pymJb9N/vcfBPhkWfvezf/ie+v/CkrulfW\n3K8bZs3tbyT8cevDfGHJLWimxsrNIsh21ZY+9vYXIFair/FZ+kuDdA3v8c7Z2S0ijQcOMrhJt2qb\nkSuJuxZ2De0hPmkDieNfrLm/qzdH4riXKI15ofY1MkVsoKsnrOFXmrbXjrAuf9veLADPrNnvbXtx\nUy8AK17uYSivIaWHyE+7H7mpx2lLRG4/8aI/Zv2Be3VzAGzYlfHb3LsRKV5gxYaemv3oHiwCsGVP\ntuZ+0zIxLfG+BX3uZoC8NSe+oC9bDJ0rN/Wjy3kxxoo4d09fPnRM2RDPOqkcuSVwEXlHiHAYURlf\ndDDapghSssnGtzGsiUm1NyMmt6DPc1XvWnoLB0524pL366F5P71HJF7al+/2tlkBf+PR4PN+ZOcT\nDGnD9BbD45zJlVEa/WVcRaOaaJUDBEC52DW8m2v+egNL9jxb+wDJIjZlNXJzd2iz+1yf6FpywPZ7\nC/6qHsuuHm83IK43EyasSvIeyVRdy6Li3rdhWpiWTWzsVtHGpA0j9jN4vf4KoaW/OMCSwu+Jz3hu\nxPMrz6nEf734E7723HeBcLR7Tisi1WWIH/8Cw8ZwVV8AJFXz/pbragsHZVMck1CqXReHCxF5R4hw\nGFG5tEc/iKU+Zd1Ead9LpvU5/mfdrwF/cnU1hf35Hn625k7+47n/8jQGoCoDl6d315hkNw5sYU3f\n+oO+l1eLxri/TjVI2G/0pWLuhAygW+EI6d5MESnuE0ZBD5MfgHIQEtOjO/8KwIPbH625X2ndj9q+\nl8TxYeuFu7xpW3YnALZR7QEtaQZFxSfvWn0cibzzAQKT6rIMpNbXJH+5xj2qTuCXYdqifcVpq0Yf\na12vp6Ivq3qFCVtOjEzQwf6XtDD5dg3vYcfQLnoKfWim7l9L0fn+y98mOfNZlOZetiUfreoLgBQr\nB/7WqIWyeeSj0CPyjvC64I2UZvBwwqwgU10/OPKWG4RWtze3D/AnJ3epUG9RTMq6pYcmm0rh4EBW\n8x++9DNuW/2/r9ifkfBiz2q+8dz3KRrVpBCc6IPEFyTv16p5W7bFrSt/7hFg9X4bpW0vcsv+mvtf\nCbuGdnt/l4zw5NybKSLFAuRtVCeRMi2bIW2Y32/+MzktX7U/eI2JDeOq+g4gN4nnbNvhB1jWTWzb\nJlN2TMuKUUWufZkScr2vKeZr9NGNZu/LhImx4JiWpbosyZnL2Bd/gd3De6vOryWfuEuuTMuirJtI\nqmjLNqsDNF3BsxAwZe9zTNKphAgEfMkhb9saWRjakFvtPefKe1m+z3cZ5PScZzavJGJNyZLT86G+\nACEhDTV8zu83/5lfrL0LzXnHa1lgDhci8n6TwPUvWrZdMYGaNY97pW2vBat61/LJJz7PtuyOmvst\n2+Kl3rUU9dIBr23VEAAOVV9/uf433LT0m6/6fLcfQfI2TCtEriMJMJpmIsUFIbYmW4Cg2VycP1jy\n/YHByaaSED2zeeC3bdvkdJ9MKv3ouiGIwbKtmtqWi9vX3sXu3F5W9673zgHxXILm/JAGG+hfppzh\nzvX30F3oxbQsBofL2LZd9QxFX6rHqrfYz4bBzdy/9aGq4yzbxjBMYsesIzaxtrm2ss18SQ+ZTHcN\n++RdOSn3ZkpIcX/cCjUEGNO0+c7z/81fup7imb3Lvevphskftz7MZ578Mn0lIaTJkh+xXiwb7O4f\nRG7sQ2l0yLscDnzTdJMhLYdhi/5KEhR1v49lzWT7/iGkhE/Y+RoCxLDdR+yYtfQMD4W254o6iZlL\nSc70a04MlAZDxzyzdzm7zZeRm3tInb6IPbl92LbtRZAXk7tZMfCMr3lXQB2zleuXfImCXgwJoK5F\nJp2IYds2e/Ou8CXh2pJ0w2JgqESuqDNYyrIz/gyJ414C/JgDwzK4e+PveaFntX9fWt5/xnJ1v7Kl\noWqzeby25m1ZNn/peoqVgfaPJHlH0eZvAnQPFvjCT5/l4jMm8fLOQbbvG+J//n0+Dy7bwX1PbuPm\nf57LxFEN/PWlPdy5aCPXv382MyYJ0vjNXzbzyIouvvWxeXS8hsjZIO7fIibbv3Y9w5SmY6r2L937\nHHdv/D2N+jF0r5zOj649J7QsA2BTV4Zv/t+LfPySWcydLrLzPfp8F3c/tpkbPjgHPdVNS6KZ0XWv\nLnPfc/tFAJBhGajy3/aZvLxzkO/cvZIPLZzGceP9VIr5khEycRumTUyt1iZKuomUEGSwfVeZNW39\nXhCNaxbvCfgyQ5p3gBwf2LaY4eQQ0IYkSZR1k49/90nOOGEU557lB9Zc96On+MAFM5h/6niG8hrX\n3LqEc04eiz3xRV7u38R/nHUjsQOMgaZb/P//+SRnnTiGf3nHDD73k6VkpN0kRCBxyKToE7PFnwZv\nB6A50cSG5Z1s2JUhEVcoayafuewUZk5u5cWe1WzdrPDw091V0dvd+XDw0n/d8xLb9g3xb+85ke/+\n5iWuuGgikmKCbGFaJorsE+SqLX384Herue59JzNrShvb9w1xyy+fB+CbH5tHZ3OKVV27vON/8dBq\nxqnT/WsnXkJp9f3QtUzSmfKgR3j3P72de34Dn79iNt/69UpSpz8ROvalbfvZ2TlMc32cz922DOnY\n5SSm+wKQVKHxLVu3n98/v5LkTH9btpynLp6mpBlc/+Ol5EsGiVk+mQxr1Zp3X/ol1NR+8oqBbvhR\n0kPlAnJdmNAHy2F/76833AdAfLLQqH+y5EFSPafQ0SKeUXncc7yUAzlZ+x5iEzZj2LBpcAu5cnXf\nDNPi9kWrKDrmckm2QBZC4pduX06PE6TWcEwXeJ+5ze0PvsykUQ3stzZXxRIM6znyJdFfqYZQ8Z3f\nPUeb7FtB5MZ+5PQw2BJIdugeilr1+bWEuMOFSPN+E2CjE7X50LM72b5PfJCWbXPfk9sAPzr0waXC\nf7Z07T7v3EdWiIIxm7p8Te+1wtXmRlp77EbuZhFmuv6hamn2ryvFMfc+4ZeSve+vIjBm+cbd/PdL\nv+CW5f/5mvsa1BoPFktWi34/tGxnyOddKOkhzbsye5aLkmYguf492eLhZ3d6y1hcTb7HMZvHlXhI\nU3DJ0bAMFu34CwPNKwDxvIediPdn13fTlfMjfVEML3p2l6O1PLVqN893v0TeKJAp1Q7ScbEvI/Yv\nWbMPy7YZGCqHTI1lI0jezrNP+pO1bulelHDZuc+la/ezYWAzt6+9i8cHBUm8vCus+e2vIO91OwYp\nlk1+/dgG1NHbWbRGmFslSZivg1i0XBDzn5buAMS6eq/d/gLL973Attxmb5tml0NLBOzOTeJ/UwgE\ntczmPaVe729TEtaRxc+NUIBJ0dm0O0NXTw7dsFCawgFykmqA5L87f1ixKqQVA2RLeedehCY7arRN\nLOW/vzm9uo+u8UFp2093xo84HypWa+lBzTtoNZJi4t4yWYOd3TnvGVYimbb454unM2/mKN4+2yfI\nn6/9FXvG/L7KvVHWTZZt3hHaJsU0hgs6PYNFL9+BlvYtJA2NYoy27xtClqvzIeS0PANDzvtYg7yF\nZu5bshLTxfcjmXHn+v67nCtWZ4orRWbzCIcSlZGiUrzAQCHrmbfcCdUNsKn000LtwJRXCzenkSzV\nfv08U63j56t15ZST8CS47MM1t9nqa5N+g6biVxOIogdyJQfHslAyQj5vbQTyHtZySJK7QNsMCS8e\neRd8Yqg1BtlymKzcyF8Xe3P+RCkpJprmulWcyzb4wlqmXE3ewTEKEpebgSpE3gEBSDMsUDXkQKT2\nYDHnBWC5SMQVT4iT04JUKpPT7A1EsQ+X/D70JlcTm7iR4lh/nfJAKSx8uglz+lIr+dnqX4aEqoHC\nEHe+fA92zH+PJMW3mtiYge2Oz7aW5q0NBI4TRDE40lI/xaA3U/RiG1yhIISA1hcb5QsBVkEEBA47\nhNubKSI39DM0cTGmFCDvcjUhm5Lfn64B35ozXK6+n6DmXWt5m6aJ96w0AnnbisbZJ43lX981k7NO\nGlO1X2kJC2NlKUd8imOSdueCWJk9TuKaGcc0ITf1hPz6l54nhILebBHd9L8LN2ZgWM95Yyyp1fcg\nqToZR8gNCktSsQXbkpDrhjyXVqYQHs/GeEOkeUc4tIhVJNxPnvIUNy3/ukfq7oTvEnStmsG1siu9\nWnjBOCO8fqZDDF6QTsW1C3qRVervUNp3UyxXTxSGUjs46GAR/AC1V6F5uzm7K8k7XzIOSvMeChCv\npBj0Z/0J1jQtbNv2JlLN1MgXq33Kg+UwWelGONZhMKhNK4bXF89H3uATT7YGebtL2AAKZoA43Ykx\noKGEzeYWyROXED/Gj3LPFKsrMCVisqfpuYFKlQUt9hd88t7W7U/87uQaxEDRHw/TMulveB65foBy\n0xZW9a1DM/yJfqhUYwJWDIYdTasQ8C3bloRqJ8g774zhPV+bQS2gPTvrg0dapy8pOn2ZkhfbYBuB\n4C6XuFS/j5Ll77cKIrmKaxbvzRRDhKZaooJXvkLz3jS4FSvhH7c34z/zXMDEbmbbkGyZTDDOooal\nwRVkhgvV38yY5Hh0S/e+p2DSFu/8eFhrjU3Y4AluFB33k6qxs1tsax6TJTFNuLdkW4yHkhTj25sp\nUTQDz0lLOPeV9yPTlWrNWVI133IQ0MylvTPBjCHFyyROehqAgYL/DZzVeiGtyRaKRumIBeNG5P1m\nQ0CajKsVmrdyZDXvkczmXiCRM2lVLrfaNbybItmQ9haEJudqbg9ix9CuUDRxELkAMb0as7k7gcdU\nudpsHiDQkTSUIT3Qf8UMBVaZlk1eL2AENJ+hgJbktl+pLZtWOFguUwoICLJBWQ8njwlOpBkt7PuE\nsHBQNH1hyU0sEgzyqQxYq4zyHa6hESZiCrudSHsMYbKsDCQKmnF39PmWCKxqrTWoea8f2EivuoHE\nCc+BLO47GImdr0HekmJ4VoVcYLy1DXNRiFN0iNHVzmPHrmJDYZV/vqPlDeVqvE+2BIpBT7ZAr5sg\nJBCZHdNF/Elw3EzE3+ZQC9aQKBE6rPmad5D8k5YgviB59xT6+MHKn3r3D9CdC0SmOwKKOdSKtuUU\nVCsVGsOagVmOcDGU10LzjLZsY2voAAAgAElEQVRtFi0J0Qc3ULJWamC5gryDEfbGUKM3Bjv2i/dR\nSfnHj5VEPIIuFVBkib5MMWzCdsZzqJwj4zyDWj5vggKSs9/oHYdeSHrjLzlj1p0V42V0T6TdmEZK\nTWLaZkjjP5yIyPvvEAOlQTQzaEoNkETg5Yx55C32u2ZzzSpXSemHMsmHa3IdSZu3bLe/Yn/l8ifX\nZOx+XJU+tqLtE9NI0dI/X/Mrfvly7dz6w4Go3FdjNvfIW5FCVox8yXCehY0yagd7ctVLbwAKhk/e\nUkVErGnZVcQ8FCC/2uQttO6gmX446ANWzCrNO6g51zKbBzX3vOFf39e8S9imgm0qlMxqn3cQtZaa\nxWOyPz6KDtgVS+L00Du6dzAgyFnV01qwv7XKoxYD1oOc5vfX1WpRDE+wcTXvuvxUrFwrshX3rDWu\nEKS2+W4J28bT8mwIERtA0m5Ckm36snl6B4uOUO2/N0nTIe9gwJfzHWtbT8Z2hJu8JvrQmy2FiClh\nC7N60KKU06sF3IGCP0Yll7wHRoMZQzHTDGnDXpayWm4C1zIwVNA9rdUcGIXZN576mMidviWzXRxb\n49sXAl8wsMA/xsoJ8pcSBU/zjiXEOOp7jmVS8ngAslqWtqYkvZli2IK2XUT2DRQDgmhgjLRts5x7\nEGOcSqj+flMV30dIKLTodSL0bSNGb6ZIWhWBevkaY3M4EEWbH4XY1T3ML/68nk9eeiKdLeGi9nm9\nwJeWfoPx9WP5wunX8H+PbOIvL/oaZnACiFVq3g5Db2m+l889bdO89VLv2B/9YS3nzxnPFRccf1B9\nfOCZ7by8c5Dr3z/b+1DveXwz2ZxGLqWBAiDxg3tXsdkpSPCpfzyJyWMasQhr3pWlI7tdf6/zcXUP\nFkKBa3nTn4SeXtPFI8v3ceOHTvMi1otGkUw5S4MzET350h4ef3EPLQ0Jpk9soXNyteb98PKdrHi5\nhxs/dNorWiFcYUOpMJs//uJu9vUXkFI54pM2cHfXBo4ZfQ0dHdNC5xfMvC9WKz7hqopMMbGXb6y4\nN3R8vkLzfu7lbh7ZvAncVNWySV+2xH/+5iVQyySmr6Bo+WSlxk1PAHIzuAXJ+4k1W2nMdHHh3An8\n/qmtDA6VGTfL13SDxPenZ3aI8+MlbC2JpOrsGxxC003iMSUsSOL4CZ3+j++oZ3dvDqV1L08Ul1G2\nXQ1JRBkPFzRu+eXzDAyVuPyiseLWJBnLtugeGgScSHTJH3NbjyPFNJ7esI1ZiX5mTWmrGVQUFCCW\nb9hLYgZItoK2+VSSJz8VIkPN0kgACScVpmzF0S2DsqlVuUJsSwJLDWt5FRpfPpNAaRVBcXv6ZEa1\npukPHJM0WxlmK/Gpqyi91Iytpfz2jJj4B+zoHeCWX66gN1MiMcrC/WpiOLWoy4M8u24/v35sM23j\nstDqdlJEUvfkMlzzvb/S2ZRkz2CWeDNgim9GNtPY2GTKQ7SlWnjmZT8S37+voNbqru0W53em2wG4\nc/09nNwxq3a8i2yKNtzgMMdaUVpzJraewDZU1NE76Fk3FmjwrmFmOmlKNEFJBLt2NI9l3fYBVm7d\nBzEorT4Lu1SHjEJ/UVhrmuvj5J0xHNf/DrYMlGHKWi/4rqkuTlmvuIf++fR0OMl0VIP+vAZpMUZ9\nmRId44UrIK8XSODniT9ciDTvoxD//fs17O7Nc/+S7VX7smUhDe52tJYgcYP/UUE1eXuk5Ex++/rD\n2vdjL4i2Hti6iJuWfjNkuq3EH57ezoZdmdBktvi5Lp5d3+1V7zEsg1Vb+ymUDTI5jXXbhfZkOaTq\nkXeFGd9dJuVOYNv2DbFuh29CzZm+VnnnY2vZ11/w2ga8NchlS5DDLxdtpKsnx+qt/fz2iS0hs7mr\nzdz7xFZ27B9mYFhM/I/tepIvLf1GTbO6YYj+xlQ51Hd3PIPEuLXGWveiJTRZ24gJE51kkUooJOMK\nQx3LveMkXZBVIUBGmmFy2x/XMRQ0dct+pLrascf3IzqIxy3vOXmacUwTpk5bwlSKvLBR+JT/vHQn\nz6zd7yWPUWWVklXh/5QNpJiOrSWxTRXd0ti0W5hcS3p4vNpTbRhogM2oVnE/8amrKdhhbV+KldnX\nX2D7viGyeY2N+4VmO8FJbtJfEM9/zrQOkulAycfhZmxLwlIK/OB3Ivgp6Au1ikIjdCO1g+M1Sj8Z\nu5wS5tsa5OvmsVY1oRWu69/gCUFuxrPyunnYhhr2V1eQt2vilhQD07LpaEqScuSQs8fOo8mY6B3b\n0C76Kak6tiWDrXiad1e+i+37hsgVdU8rbUk0M8Y6EXO4ha58F89u3k6uqLN70DeB1xtCECrbBbbu\nzrJsXbfXx6ljBOm675rrLtm63/fne/0P3CMO8br7zh53BhMbxmNjM1QerhKgZEMoIak5j3uBeZKq\nYVsSdrEejDj67uORZBu5LksqoVK2xHc0a2In582ayrFNk9kwuJlp08XzH8yL91wkh5EYk5jAgN6L\nlMgzujXtPceGRJpPvWc2tiV5yk1bU9IXnB3yTlvtTFBO8PqWLfrfaaGkM7vzJE5sn0FnXTtHAhF5\nH4UYcgJC6pPVfiP7gCUvCJnN3UxIru9VqTRlSbVNzot2Pk5faaAqorkWatbudYSDsiHuo9NZu+ul\nAK2INh9Z8xb34i6BclG2AxOxc7/BW3Ozk2mmVm1Wl02yAZPycCk8ybhc/IctDzJQGqyZdcowLZAN\nCsmuqr7Hj3+exPTnvd+1AuLKtiBDu5T2+pSIKSiyhGwEAn1KQrovB8hINyykZA65MbBGOKC9h9Jo\n6qItJRYgb9MCbKRYmeZEI7aeQIqXqtJn7sntJ6HEmdQwAc0uhd6Vjsni2tZwC5gKyCb5ongPKrN8\n1cXS4n1QjFCZSxcJSbwbUkwjmw8kRTHFxDyrbTqqpNDPDuIxiX+7ZBbHT/K1HqtUJywA8ZJnBXH9\ntbGhiRyfOA2A/ny1sCNZCiCBWaE5O/uTagJFlogPC3J9dt/zvrAqmyT0Nuxio/C31iB/M9sqtErX\nv+1sb29OolllpjQdw+XT30NCrqO8+RQAzjhFVC5D1T2N2y6lMTPtKI0DyM3i25Ad8rzm1I+RUJKY\nvULI2Ws6Firn2zH6RzPJnOe0GXgXnb5ceubxpBMqaOJdcZMDlZx3Ttt6EqUX345VrAuR97hRTh4B\nh/iS8RjHNYtqeHkjXxWVPaV9lD+8DQNCaFUMMGOeddF2+iCpGsm44rXxrxefQjKhcsGkc0UDDb1M\n7KzHkp3+OO/8xLiwcClt+xjVmvb6m5QTzD6uQ5i9nW1j2tJV1gNVkZkxfpTTB92L1bDNGLppM731\nOD520j8TV0Yu9XsoEZH3UQg3pWFDuka6wVcIsAp+YO5yD9eXqCgyECAb+cDZygyrOjBDMzVRLEEO\ntx3KmuWQd8kh77FO3WOXICqXihkBn7duGV6gkrf8JlS9yaZsB5f4iD4G/es9gexfwSUvUqJA6rRH\nWbTjL962XLmCvKtyh9fI8mZaxI9byZ66p9haeDm0T2nuC/2u5VPXKGDbYDlZtSTFEOStSEiaX3fZ\ndLRGzfKfeX+5j8TMZUiq4S83CvrNAwFKdllMikrM94drugmqjiTbNMQbsEsppHiJTD6Q7U6y6Cn2\nMrZuNC1JQSZBa4LWuB3bkjB6JmBbigjGGhSknXeimG1TYczwOYK8Ee9lMmVWvXMtsrOkKFYmGxDS\nipYg77H1Yzix/QSM2BAtHWUkSfKIxb1HW0tCrOwJGG4msobisXTUif5nQwF8jvbs+DhtUw2Rr/ve\nJeQE8ZiCVaynM93OjuwuQd6ShSTbWIZ/vhCgrND5Vq4Fu9jgkYvSIN7rlqYYNjZJ1dHsFckjLtcq\nIyl6IChNwth/DBAonOFcI6UmUWQJa0jYyPP0e+MNYPaOp0FtAFsKPUM1bnrnx2Iylkve5QyWbVGS\nRTu2HgdkQdJObAKAGjP8sUMQn/usl+xZ7uVkd3FSu59tJnHcS6RmPYuk6khmjNaGROBaQEwnHlMo\nODEPKUX0rTMlNN5sOYuqytiyLtwWtqC5MbFjkWwFdew28vWbkRQD25KIqWIcG2PNSIkCclOvqG8e\n8Hm7z8H1a8tNfdhjnRUThloVVHskEJH3UYxaUeFBM26tJQth8hbHFsoOwcmSZ+6CEaIxAyjVIJ4l\ne5fzu81/Iu6UKHQTHlQm+we8oLr6VIzm+rgXqeznwq4OWOst9PmEqRiA7ZVelOoyJE56GjsogDj3\nIAX81K7mDWHylOurE9G4ZnMXlZp0Lf+pYfpJNoaMAye3qWV216WiiLB2J+eA5m1LgQxteYe8bf8e\n9pZ3ICkmetfxGN2TgLDmHXz+linGRFZNz/pSMjSSJz0FQL1aj1VOIUli+ZUXSZ7MY9kWY+pGe+lb\ng9HphprHLtWBkRBaqwQ9Q4Jsi6azpGr/MSQK47wJXU6UeEL/XxLTnwtZB1plx7edyntCK0DRsa40\nJxqZ2ijiMFIteeceAuRddDRvyRcw3ICidCzFqAZB3iUr8Jwd8rZ0550xYk6kcfC9EwlyknGFsm7S\nnmwjbxTIlYre+YYhuwMi/ne/rQpSsDVBCLGJGyFWoqlRXNclJUWWQBcEVrByoh+q7hEj+OlTvWVy\nAdO+KsvYWgoZGSvmm91BmHx1A5JyGjk9jNK2l8TMZ5AdQSKlJokpMmbJ1byz/HHrwxjNwuftWg2E\ni8dGaXfW5scCPnkH7rNetm8Fd738W4KYN2Yu7516iX8/ySFQdFRJCEiiLdcXrpGIyRSMIkkl4WXO\na3LqqWfKQyKHhaO5e5kiTJVYYTSSbLFOe1ospzNVYoo4/8KxFyFJoI7ewdi2tDf/ueMcU2XqnMC7\n2Litfl/N2EEVHDrUiMj7KEN/IUPylCeQW/bXXCccJCOj1gtVg7xdYrVtO+QTr6V5u/5qqC7WAHjL\nJJTGAZANr22fvG1PA3LN5om4Qkdziv6hEoZp+ZHyjoYeLIPZEyBeSRZtuZp3/NhVyE7mLjcgxtMw\nAm0EyTtoqQhOhi4KevgeNcMK+fprJWUIErxsB9s8sLAFIg7AUHJYpTS25ZyrGCTiCrIsYztadHnT\nqZ42plt+H12t08o3CpO1c76LEHk7yT0kxcS0bAzTIqsNeoFCti152rmUKLK/wmffnGyixZkw3XSu\nSBaWpHt+WDdCtzebc8bL0byNGLph0eFoS7HxIpuZXJ8N9bfDnoptg9wUWAoGDNvid0uiBVsT10qk\nxHlFowSWjLZtltBuXXOrI2C4qTjr4knGNAvhQx29E6VjlzceAGXNyXtQSiPJlne+uz+pCGIp6xYt\nSeH3HigP+uTtnO8SnEsGleZYs38Mck6YY+Vkgfp6cZ6vecvYegJsWJ9ZizJqJ5IEiuW7GWwthW0H\nnoOsE1fiKLLiLAGVSMuNSIkCx09o8oPLzBjDeY3ZTWcgqQbxY1cj1w1jJ7NOH5IidqMk+rJjaCeP\n7XrSfxCGK4CIMY5PWRsao+A35Uac10JKTXJK58zQNkm2ScpJLzmPa2lQO/YwNGoJBb1ISvXT5SbV\nBEklSaacJaZIQrM2VRJxcb5hWpT2jQ9dwzZjnvvw2JaJIj4hVqalIeG9h+51Fdm3HoTa0OO159rD\njIi8jzI8vnOZSBRw3Es10xAGyeBHf1hbtT84eWu2+LusmTy8fKfIchUMOlGq28+X/PaD5kkXwQpS\nUsqv4OOlHJRsQbrgVeJJxAR52zbc9sd1dGcdE6ZD8oWSwc/+tI7t+4ZYvlVIvHaAmAbcDGSBicI1\nobkfYFk3uOuRDdz57OPszPjpX3/1WKAkZg0ff7DYA8AdD7/MMxt2+PsDWt7zG3q4/+ltXoY1gBUv\nB7JG1bBkDBULWJbN/zz0Mt/9zUrW7u0CyRZBOs49SorQvFVZwpZ1UnIdVqbTm1SCfmQ3iMc2Yz75\nB4Uwl0D2noDpaeZim6abFAMa6JT6qb5Glyiw28ls5b5DxbzM48sdM6yreTt+U9fE6fZxy0AXS9fu\n8zVcI0ZXT47lS2JYxXqkej/gUJIgZTfz2TmfADOOlWsRVhGnbbmhnyFpPzNaj6cp0YBWEtey4jl+\nt/lPDJYzKHoDZt94QAqR992PbWbzfiG8NSTqGN/a4l03Ptl5FxzXgusxsYvChy4lnVgKR4BKqkkS\nMZmybnoWiKyWDWhsYdLxfMoBUhjVmgYkFIe8pXiJtMNHSVX0W5Yc06/zPONOXeykEqg1YMvCwuCQ\ntyXrYc0dyA6oSDGd3WN+i5zyg62GCjpnj52HbVbTQVJJEFcV8jlQibNreE9ovyuY6Hum+hslv5JY\ncL16LeKb0Xo8lxx7MZIk0ZRoqNqfjqX9zHqm6iXsKSf3M1jOkI6F6y00JxrZm9/PQMOLYi4zVS+W\n4p7Ht1AeaGFy9p3e8dZwi5dpsi6pYhtxJFVHVWTf8uhp3lLoHqxcE+WNp4KexDBtNnVl+PkD6w95\nIaeREJH3UQYvwtGSR9C8fXJdvbWvan9wOYcZINpHV3RhmFY4KrZm1R0/GKyW5h3URCXVr9Lk19AN\nrNX1yFtm1mThk3txU6+fwcohnadW7eXZ9d3c8svneXHnDsDRLBHLSTzLgeFrIp3pDq8PAC9s7OXJ\nrmdZXlgUShm5bmcgM1egb66/2fXLu9jTm+dXj/vJN4JLjH58/1r+9MyOUCYwSxJ/nzdnfDga10Hf\ncJ6+TJElq/exbscgK7YJ4cQq1vtai2z4ZnPZICY596m7+Zb9PnqWF1MNkH+15l3YO9Zr31svr1ue\nyVnfNY2J6SkhzXt3r/PsHRLq2qvRtdsJ7nK1UvcenWdh9ApNRx23lb+8sMeL5G9K1lPWTdZvz2L2\nj64al3pzFJObJqEbFtZwC5IEckoIdW5ynrdPOFvcS05MY7uNjSLeApDxScMn7yKPPt/FcKmIbcP0\n8e00pfzJ2DYVLjhtgvfeubGK9YrQqpV0nqb6uDdeKVVohZpmehaIgVLGF5Zcn7kTeCincsSOXYXa\nIVZttNbVMXl0g9NH8b6NHS15edBd8vVQIfx1NjaGfttaCjlRIjZlFSYaKZf8HfK2HdO7jS0sHDbU\nxVNccf5xjG2vJ2ZWL29SZMVZlSJR6vOjqMubT0HvOs57zu3pZibEnWWkqk5RFXPP5I4OLxVqXQ3N\n+x+mLOSCSW8T/ayxfGx0U6OnOYPvv3aRVivJ2zGdpzYiyRa2odJQkRBmzqQp3t9mpsPLQJlMqCTk\nJHJcFwKPa4Gq4bcHMHomYGU7aUzHMEyLp1ftZdm6/fRnj0x+84i8jzJ4ZGHEvIQQQYQCoCo0ye99\n8kzGjvJfZMM2mDymkUmjGiiUDQzDqliPWi0cZEv+MqNaPu9g6khJMTzy9uoDy7XIW+GMmaOZM80h\nXOe6biajYPUeOT0szLmOyTc4oU0d3eH9PaqCvLsHi6GsX36DZu2/HXPgLv1lnt+/MnxOILCndi7j\ngHnc6d+0Cc186rJpVUdqZjkkhA3oTgnIYr0n8UuqCNBRFAkUHQXXzxj0AYoJzrUE2IbqTToN9YHP\nXNHF0idLEe3bEpYsyLikGb7mbsQo66YnxMjJgle8xBUWTC0WIka3L+75AP/90X9gUuMElPpBhu0e\n9sbEWH78nbO5/v2zAbDyTVXjIluOS8AwPdLxVg441291TNXZrF1V79pSfWuEbz1wBQwDhRjzZo5B\nkiTU3aeKZUKKiT12LRPGiOuVyqKm9LX/cBYA8+c1M3/2uEAwWIJETMEGGmMueQ/6AW+u5u1o7kpr\nN2rbPuQ6IYTc8E9v83IPWI5PefrxKe+7cjVvL6eMHo7GnzVugvf3tz8+jwmdghzV9n0YUtkjb1fz\ndp+Vi3Qsxa1Xn8Ox45qIqTIzx4r2bEOl05jBW8ecDvhLSs0+EX9QrzZgDY5G6fdzPnz742+lNS2+\nydiY7QzYuzm+ZSpffN+5/MvFM4BqzbshVs/ExrAZ28qFBZJxLS2hnPZSxZyUrqHNB2FrqVBFwpnH\ntPD2U8czPjlZXG+oDdW5P1mSOG5MBzYW31/zQz/4L0DeDTFfwDH7x5KMKzTVJzBMS9R4l6Ct6dBU\nX3wlROR9lMElC3dyrUTIh1rxosdUORSJbdg6MVUmnVTRdKeggHJgzXsoQN7lGpp3KFtWgLx9zdvv\nk0fejmTtfaTudR3ydgPzpGQOuW4IK9vmE1egv2ogAdKYOmfpiUMm/dlSyKzuJVEIEHZQcLED2ZTu\nWH936B6DwVmVZnWlfTdKu798zG1TUYTJOwjbktCscMrUYVMEuNmlOo+0pFiZZFxBloVA4+ZxxlKw\nLRkppnnpJstWwIXg3INhB56pE8ErGEEiIdWhSeKZarqFZrmFMWLifdATyCjIyYDP2yFRvaSCGcM2\nFc9c6xKrazaPqTJtyRaQID9KVGhS+o9lcvNEjxRcK0pojB0TsW5YHmkpDYNI8aInILgTaX8mXPEL\nwFQC5K255F1EqsuKwCzbJ8J0cRJG9zEAPLN/GT3SJm98VUVmVLodWZJZ2buGPnmrFxOQiiW9dzet\nOMVB9Kz/jjvjbzlL+jwyQBBZc6LJS0lslcWzzpQy3mqKSh+xvP2tnNJxovd7Zsdx3t8xVWFWy6zQ\n8UmPvMU4G/smc0LsLG99eqXw3ZRwnoNkc4w1jytnvFcc5xatGWrj7JaFvHvM+wHoqCCphrgjPIze\nSVxOcOnUd4b2B8n7golv4wunX0MlyhtOp7R2nve7M91BIjYyTbkCigvTDs95drE+RN5u8NtFoy6l\n+OLbwYyhBoJZ61RxDz3FXuRkwUmyI85RFZn6eB0fPuFyJmbeAbZMQzqGqsjohk1vtkRrQ7KqENTh\nQkTebzDolsEL3atGXPJVMv3JtbbPO1A4voJ8Y6rirSEGQLZEBKVTYSlb0MKm3Rqad1+gwENNzTto\nNlcM8k4ke9Eh76CJ17CdJTfOB5WIK0h1Wc8n7loO3OVZSpvwVZt943yTslJtogY/o5N7P2U9LJgk\nqHP6GPQHB9ZDl/2JqXISDd5DZWrP+JS1xKesCbTpkLcsoVNhTrNUdFsLZR3TLFdzjnvFFKR4mURM\n8QOeLH+JkK3HQfXJW7c1Z3mM4pnNTcIJQkLEJTWiSQWQRIpUL3LdUB3BSyItNSKlh4jPWYSUzHkC\nUbkozKlWoQEplUNp24OccM+Pe/fdFHdIIZHHzHQwpnwasiT7BXMMv7b4jBZhnVBNJ5LesDxBQB29\nk+QpTwrLhy15/s7eTMl7Z+aOOhWAMaXT/HE2Y9iGitLc65fRDFilEjEFK+dr/xa+5qwqMnElzsJj\nzmNYy/F88REUZy11OuYHU6WoR5UUitKQJxDalkIqoYAR9zK9uVAlX5sDMHSFhBJnsJxlq5NCdErT\nJIKQyg28a8qF3u+JjX5lrrgqc+boed56cPDJ0hUQsFSmpWbzDqeNyY1+8hcARXIEViksCfnr6yVa\n9anETfE8O5rDxOmSN8C5o89hQsPY0H41UBP+H45d6AsLQVgqdqGJy6dcwTWzP8bcUbOrqskFMXfU\n7NDv9x1/ifftg1jn71aQA0g6wlZSjXvvnRog2/p4WJMXFj4xfm5g2+mjT0XVxfuSTsaIKRKGKQJn\nK8fkcCIi7zcYFu94nP9Z93/cv/Xhmvu9IDFbqql5B3OaV5KvLNuUrTC5xlWZdDIG2BhNO/xoVaiK\nNpcb+1jc/cfqvgQQIrMKn7c6ertXHxfAQiz18j5OtRSuUSyHydsNGDKHW3yTskNoqiJj4pN3a7KV\nhJIIpYN1NSZzqIVOyzFhBwQc19xp5ZrQd87wtlea+4Jthgs01Fiap7h542VPsDGHm/nQ8R/ANhVM\nWw/lHNesssiFbckhzTsekz1BwF0/DIARR1I16p01/yaaFyTkBqwFyRtVJyb5ZFkvi0lIbhhk7eAa\nj7xtM0beqaLVoDoR5bKN2tnlkVCx4JiF841IEsSPXYM6QQRTeZYRSQpN0ma2jQ4nKU88oFG17lnI\nl97yWT4y/QOUN84hVRJJRXTDCsUyiPHQUOwEsiRjWlaoZOqxzZP477d/i3GcGDonKIyBsxzPQTyu\nYA2OQt80h45Um3+Q5QsYFx9zPh+Y8T5/V66JZEz1a0obNu2pdqz4MI0Nzn2ZKumEeBZuJjcXHzxB\ntOUSgmHatCRb6C8OsGlwK82JJi8ILvhadaTamdQwgYXHnIcs++MXU2WSCVUkxnFwaufJ4hoBzTKm\nysyfcDafnfMJPjDjn0J9mtYqgs7M3rApOzPsv++9mZIXhNpWURmsMemblMc3d3IgjFQO2MVJ7TM4\nrmUKkiQRj/vvu7Z9JlY5yWzlHXz5LZ9leutxofPG1o/m+jmf8n7bxbqQ5u0+r2CKYzVQddHVvF2Y\nw63e30GN2nUDphNqiPzbm4+MyRwi8j6ieOz5Lrp6wqkphwoaDyz1g5y2O8kLdgyJZSuPrOjinsc3\ne8uhvCQNstCUlq3bz12PbOS3T2xhqKCFfd4V5Js3CuGkIgHNW27qJT55HWpnIA96KEDGJhYo4wjV\nAWuWbYfK5EmKEYo2V8eJ7E5mthUz60ySkuV9nHvl1aH2RE1rvw61p7kYcZ/YHD92XVL1lr5JG9/G\njq4yacXPmAR45KdvO9ExHVdq3k7U9daTwIxTeukcZDPBkBZ+ZkENasPgZm596o/c+9ctxBM1stsF\nNG+3kIax5zjmjj0RLAXN0vnLCr82s47uCCaSuE9bglhZWCVc4UMPrO/V40iK5UUoIxu+VcJdR+ya\n6yUTSbZIyP6k26AKv3HsmHX8pe8BhmNOX4yY9+wanWPASTiiashWjGLJoqUhUdNnbet+bEWQvO1S\n2iPvYKnaJI2MruskEVOxsh1YTlSxHtC8XcjJAoolnv/9T2/HtGxithCwOlLtSJJUpa25goxVrMPM\ntnGccoa3TxwrYWY7mBYkA0vxtFZJkpg35jRa4+K9NbonElMV7zovbupF1uqRFJN0Y9k737VqeX57\nQNtyMjNahb/YNWnbNqT0jnwAACAASURBVExrOZaSWaZgFJnaPLlm8Q5FVvjc3E/xrikLKrZLwrxs\nJLC1BAoqJ7YLAVRRwiQPMLlpkhfU6eLE9hOIbTsXfdf00PZgVbvebJGCM1e11CdCxzUHyLsj3Uot\n/MeZX+Rrb72h5r4g4oHnlwz8bfZOoLzqbYxPTmZUXW0BIRiBbpdTNc3mSlCgCZJ3haBuF/0I+CDJ\nu0pJXSoW2t4RkfffH/b05vj1Y5u56X+eC23/5cMb+MNT2/ijk6fc9dmokkJfpshv/rKZxc+JZTa6\nqfumV4e871y0kcdf3MOi5bt4fkNPyOddGdzh1om2yk6QkWx6Pu9ahelD/uBE0VtD7aLSbL5++wAl\no+RPtorOcN5J0lLWQRITsbb5VH8Nsmx6H+cwvdiWRGnVOZhZ5+OXA+StaiKBhy37EbxOn+rTMXRL\nx9bjFLJJfvC71aTUdCi5hr+EJ4ahy1X36Js7nWIMRh2y1iisCZIFWMSnP4fS0uMtWQHYYDzDw8/u\n8pa+ubBtKeTzdjNC2UYMWZbEGnDZ4PHnffK2EMk3Jo6qByTQ48LnHTCbG5r/2bpaaSJtihSwiu6T\ntvMcOjqcNe+uoBPQLprjzc44OlYBxe+jm9K0PuaTr6Tqjt88TqFk0NaYrE3eAW3ZM5sjfPluLedY\nYFJWHRLz/LPOulndsJDNMEkAyGaSkmbw4DIh7F7UcQVXTv8nprUI7TFoKhX37rgjSmm0jXOZnvTN\n6u77Z9swJu2n6cSWQxM7wPunXIm2YwZm/1hiqkzKuc79T29n5y4np3Z6nTjdUkgnVc49ZayXZAXC\na59PmSpMvP9w5jGcMdrv07njzwx0vur2PbQ42cckSfJIpLT2TK4c93FPuw1mF4yrI5ugAS47ay7Y\nMmec4I/DhXP9wLjeTNEjrvGdgqynTxTvUFOAvNuStcm7OdHkrYmvhfEd4t0MCl+1zOYH8oMDTJFP\nQ983GZBFeteKtoKat+dWIOxDNzPtmAP+OAQ17wWnC5fD204ZGyJvNxvckUBUVewIoViuvfZv/4CY\nLF3Tn0veiqzSE8gnPVzQ6Q/UL0a2KGtmyHReLBuUpaDmHSbkYUeDtMtpSJRANompCnXJcO7lifHj\n2aVtCpO/G6S07xhGG7Pon/DnqoC1fFlDUkyxbjemISkG/UMlLMumZJSRZJvpbZP5+GfO56tPbKef\nHpAt74PSEZnF7HLaXx8qW77ZPKb564fLKSRkSAhLREdTim5LCwWaqXZSRKzLplgj6yWmUCmXbVER\nyCHsKWMbSU2qZ1sOT7CoS6pCY0oBqoYkWSL5DAifbkX0uhtjEJfjDL94BvETlgc0b9kb/1s+IqKX\nZTuGpYhc4u4MLSkGthHnuHHNfPby2Vz/2FKkZE5MHE77Wjkwm7sJJGI6LQ0xioqF5Wb0slQScoJU\nyuCmj8zllj8+AAjTopsfqjXRChUp6t1o9JyjeadUn4ileAlUDb2QwrJt0kmVn3ziHfz48TgbSi+i\nNGRoUBspBsgqpHlrqZqatxfxK0tIkh+kqBsWsRqEI5kJT7iYPKaBK+efSl+fbyFprwim0ndNI3Hc\nS+h7BbknAqbYoJY3OqTNSSGTKMC4pk7MHuGLjqkyHQHTsV1hGscU39aHFkyjdetuFu0Sgkaw1vak\n0Q386NpzhGUFeNv4M2lPtYX93QcoV/Ctj83zAh49Td2Ih4g0SE6V91OJd501hVMmt4a01ffNn8q7\nz5rMt3+9kr39ec+d0tqQ4NZrziYVF8cGY0Nqrek+GHz5I3PRdCtErkGzubftAH5wgOPUuazrEgpR\nkLxdn/dImncwT4W2KRA3QVggPPeUsZw+YxTppMpTq/wA1WT8yFFqpHkfIVg1UpUCVaYxw3KLhMih\nYhCFkkFf0c/JLSt+SktX8ivr1oE1b6fghuf/U0zH5+1XParPT+WczvnO/mCqVD/pQn3MCc6p8Hkv\nyQo/va0lhd9WFVWSBoZLXtnIpkQ9qiLTXu8kvohp3sdZtovexGa7NZklV/O2QdWwveAmmTq5Ednx\ng7c3J0WQn+l/1JYernYkMi4JE2nZ4V3h47dJxhUM1zfsCACphIrlraUuh0zwthFnsiYKIXjJLZzx\nPnPs6SiWWKftEroiS2TKQ0hIjKpvce6gVhIVE0yVeFymPhUThUEUC1s2sJVgoJjTDyeoTZcLtLW4\n5nKfHBpiDWS1Idqakshp8fyntvqaVGstDckQZnt3kp7deiqzW+eIcUgNI8m2l9WsLqkSjym0SZPR\nd8yEgQl8aMpVBNXFUGCSLdf0eYeIXJFFwiBElbRa0buSmfDM+lPGNFV9R5WBQ9bgaN7ffg22YyUI\nam5BIvdWKQT6EkRdYAKPq3LIxxn0N4MTsJZUkSSJ9nTAOmGE1x2nEiqyJCFJEv90/Lt5+4Szqu53\nJKiKHCLaWvcUJKr4K5C3JElV7cnOto7mJLphsddZMphOxqhLxjyiDa7jrmXyPxioilxlNamteR+Y\nvINCyiuZzYPHnth+AnVqmium/2NVm3WBQlCSJHn9DL67ifiRo9SIvI8QauUZr7Xd07wlhd6MT475\nUrXm7aIuJV4iTTdDRSrCmrfN5sw28Zdjcg6bzZ3lN6UpXuIKKWg293Ihq6QSKkk1GdK8dctgW2GD\nc7AFZswj/L5Myat85Urnk5pEQJJclyURU9AtQ0RKuyZ3JxmDu9YbVRdm4YD/s0FpEfV3FZ2OppQg\n74DmrZXcQLhAZivHZFly5CK1bT9Kx26x3MPSPHJXFYlEXMEoOwJATAv5usHGGhjDhORkp9604Res\nUBNiQgsUtFBkiaw2RH28zsvFbLqme1dIkiyRWMJUfFOuQ85lCuiKIF+9GCAMxyeXNftobnL8pwGz\nbGO8kbxeYG+xy8vHPGOUr9U1JeqIyeIeZdstpOFkbnPMo+lEgg/NfC9WOemZ122nIlk66QpbNnax\nAXXvKbSkwlHESSXBuPix6LunoioyTfV+JLoLNaAdKrLkFXoQmncN8jYSnvm2crKH2r7HukCyjrBZ\n1m+/MR7O8hXsl/gd9h8Hr2NraYrPBXzRAZ93Q9zXhG0zTN6viFfBg0HNVKkIbHu1cO91pxO3U1dJ\nskqck9pnVvnjXytqEXWlUHWg/alEtQl+pIC1hng93z7nZs4c+5aqNmu9ZxAm/1cSKg4lIvI+QhiB\nu6vglsNUZYW+rK9590gb+O2m+/0DA8TqlgYtaWaIUIOat9zcy7J9ItLbKjnLpGJlYorsmM2dNddy\niua0Y/JSqoO9MEVO6sZ4AwOlQTQnA1le9zOvGb0TsE0VJSau35sp8v/Yu/P4qMqzf/yfs81MJpls\nkAAJ+yabICgo4i5Qt69WWxUXcKlaRVu1daFUpbUPuFT9Wbva1trqQ12hllddeLpp1YLWlcUVtAjI\nkkD2zHaW3x9nmXMmM5mQZCYZ5vP+h8xkZnLmJMx1rvu+7uuOWevL7eA9uXqMeVyl+/CrDx7Gqzv+\nbZ6npJ7Y/knrzKBmNUZxFy/ZhVRCoB0DyvxQDc0zbG533rKDrjkkbZ6rcLsrWJTvRWNwM3a173Iy\nd0kU4VckaBFX8PZUrsdR3xSBz96yUo45vxO/5IMkCOb6Z8mcKxdFc7cjuwMUAGiqfYFiPq+4OLGk\nx+nnbL3fiNaGmGiNnESCieYeVrOaFr0egZB1fK7gXR4wA+nKj58xHx8NYGBxYs4x4JfNoXMAJdoQ\nCLGg01TEzmz9PslsEqO5A5V5UWF/gNt75IiC0GGeWBAEnFJ9DtQvx6KqPODMwbqzMzkp8/YOm4uY\nV3YhYlumIbLpaMS3j4PSNNK5uEgOIgBQXtJx7tGdOaWbUxUEAQtGXYzohzM7HFcyRRZR2mFnP8HZ\nIcuI+Z2LG89FQYoe+r3NfUHiHjbvjeAdjWnmErqkQCUIAr459RKcMvLkbv+MVFLOb2e4oHFfdLmH\nsv0phs2TL9DSCaYY4QAS9RoAg/dByT1s/t6n9dANsxfuftd2llu/bEJMtTJcw0BdY9jJAPeXJ/aA\n1ttLrAIq8zXNDy8De/WtqI/sT/xQV/C1h5dlQYbeWG0uMSpuhqJ4h819QgClwSLo4SDEUKPzGu7M\ne2d9G0q1oYjpcTyx/jWs/2C3s7etumeY9fqJIri6pjBiVqFdsbWOcnRlrbmOdsBufNGyHau2/MU8\nUCt428PmghL3NOZwF0KVyFaLVCWC0lJ7eU7iP09Ts13oZm1VKCUqsdvaBGCH+SErVdShrug980nW\n/2NZMiuW7QsdqWq7N/OW4tjXHIEMa3jWmuMHzEzTzLytD3YljrgRQVxXUe4aQraXfCnDzKYgpSF7\nIwvZmUqwq5TbjVZExCYYutnD2mn5GPfDiPuwH19ik/oP8xQ0JqqI7S077SmX2JbEOmDAXPs/sMgM\n3pE2H9o3HO08xh42d9bhC4lhUfu47OBk/32bc9YdPwztAJuuGtcdJCVR8BSs+WQRgwODoe0fAqO9\nFOquMdBVn7MbXjDFvvbuzMo5Blfm7Q48/qQ51YmV46C3mFXlyRciboospXyvV0y5GPH35gGaL2Xm\nndziMxv8nmJAd/DufnAZ6JqKSHXBlC2ZsuxMz3FPz9gXAuky784Up/g7AwBZTrwWg/dByJ15P7Rq\nA155dyfuXvmOp9HK8sfeRn2zOTcc0+PY1xxFZWkAJQEFQsRd9GP9J7IztiIFUtUObCsyd/s5acDp\nAOBtB2oF4UvGLwIMEXpbGUR/GIZoNfiQzbaZPsmHoF+GVl8LQdQhVe72PB+agoaWKN57y/xD/vN7\nr+PXaz7Axm3m4+zgamgydMEMmvubo4gLVvC2Mm9REJ0Mz3OeUvTrdg9ZuzPvErnEeZ8lQSvwuTJv\nLe7q2Caaeyy7s57wl0M9VePun2suvZGgt1ZgQukkSKFGSFWJZXSxz6bCMIC4nZl7Mm+/uYey9f7E\nYDPaNHsLy0TWO7bYWspTuQcQVZSUJC5A7A+BgUHz8a/sfx5hcb815SGgusIOgmaTFA0xRIw2xHeM\nxfTBiTXqdvAGzPXtRpv5evaccNAvozpoBqq2Fsks7LOCS0u7+Tu3i3wG+BJroO3gbVfXjq01f85h\nYwc6owLuoFhZGoAAYGhVx9854B16lCUhkXlrZuadnDFqmp5YrpNuODPpQzlV4RLQ8QPXPUfaWYGX\nnbFVliay/AGl5haVMsy/02CK4J343XXN6CHm//3Dxg3M8MhERul+T6LYu5k3kH4IORvsn+WTRQyz\nKtyryjpvhuK+6PLMSSuJkbVU3+/KcSSTPXPeuQverDbPESczKWmAPPQTfLjTu7ymTdgLsbzOmeON\najFEYqq5jlY30KaJEACE6meiWbaWFllV1MGA7GzacP74ryKyx9zooXaIjMtPnAXdMPBKXQPW7/3M\nyXy11lKIZXUIi/sQ9I81s1PV3NtWFAV8e958/PKjTzF1qoh3/55ocFLiC6IZifWPdr/o3U1WW0+7\nGMfOOiUVMVWHjihEeCtSJ9YOxsdNiZaR5vPNDz17pACAOd/tt+daXQ1GrOB95IwifN76mfVzXWug\nnUYuGgTr/BiajOKAbA25CmZBmL9jsxnJNSw4o2w2Pmr+wNmJKbLhWHO/agDtLQJQYgV915y3JEWc\nJVRicTNaVfPnuzPvb596PJa/vBX75E8hKDFUlgWxwzpGe8574UmH4uebXI1rrO5XQwYU47yTxiIU\n9OGdLwdgXcM/UOoL4ZRDvo5h1SFcEDYb5OwT/us89fCRI3HWCWbryTsunYn9zebWh9VNZlCwh8JP\nnjEUf39nBzTdQHFAdoYd506agj98bG7KcsnJhyEkDMCU0WbWfszUIRhUUYTRNWaf7GWXzkSFK6hV\nlRfh9kuPwODK1FXInjlvSUQsrsEwDKfaPDljVDU9MSef5kP1vmuPxusbduGZl825fk+Rmiu4JQc0\nd5CXU2Tw9187B63huJN1L7t0JprazIs+ewcrO4ja2Zq7u9hti7xVzJnMnjIYA8sCGF2ToiNZkvsW\nH43G1ljSnHfXC9Y6M6DU3BfdMHIbvAM+GcsunYnykB+KJODLfe2oTXMRaHNfdCkpRlnSFay53X/t\nHLz18V488Tdzu9p0GXqqi4NcYPDOEbswTSzdB6m0Ac2tewAkrh63la6Fuyg3psUQi5vLqEQB2Cuo\nKJaLIDQMhVi1y1xcJOowAAQUGYJsZn2TB0zA3z7ZZ3bv8oedtZhavVWQ5jOvnu250jZhPwRBgCDH\nYaiJhgMTBtdC/FhE3JpntTPvimAJmvfHzbXWmuQUpe1vbzH/muxqcSdwxs0GNIpVze5aQlJRFAK8\nsdvJrGPbJiIweb35GnIMorVlpN6ayFxDivke3t3/Nt7d/7Z5pyvzdobQ5Rj8483vG9EihII+TB0z\nEOs273aCfak+BOMGV2P9f8zzJEuCk50FjUrobaUQi5s9xwgATc0wg7ccd4oI/ZIPoiA4PbvF4ia0\nqOZzy1xz3n6fhOpQGfaFzfqD8jLRXLalJ4bNq0PeavD4DrOJSHFAdrLYE8dNw4mY5nmcX5FQWQpU\nxkc79w0vH4RqK3sqtiqFAWDW4MPx7KsfQ9s/BJNHVngyQ3e2NW5AotBt+qihngsxURBwyPBEtfWI\nwR23dxw5OH3w6ThsbjhD58mZt98nQdONRJerNMOZpUGf50PeHdB8nmHljnP0smQeQ6oP9oqQ31lf\nDQChoA+hoLeRjP1+3BcCK+bcDkkUUaIcWMFa8rntTFmJH2VJ8/2pmrR0hyyJqAz5sa85mnYIOVvc\nf0/2KE9n5DRLwVIWrKW4QAPM3/PIFH/Hydw1BRw2Pwg5w+ZWT+WInmKHKxd7yZdfkcwPJ1GDIvoQ\njWuQkpYYybIASUlsU1jXGIERC6BdS6x7tVtzhvzmB669WUMM7TAMwwrePkSsHbxkUUaFvxx14X0Q\ny/dCHmAPi7u2WlQVZ/mUvduYMyft7GGsojUchVhsZuYhV+FOyrWg9rB7WzmiH1vLk5QoxFAj9EgR\nEHd1B/OluPp27fhlX0BIFXsgKHFIrYOh7jSDn7OUyPp9+EQ/Lp9yEcT95m5DdsEaYFZdq9aOSoYu\neLL7BmsBgFi6z5mXtzd+QDwAI+aHEGzGjlZzyL0maSlSib1LkRxDaYld7Z0YNncXOk0UToTeYI6q\ndDXzce+6lG7tbUD2I/blKECXMLC8KG27R3exXbHcvXW86aQqWItZ65cVyRu8Az4JqmZkHDYHktY4\np8mQUs2P2wVv7u1dD4Q9kuD+PZX5Qx365OdCb2XeQOJiLpeZd3d4Mu8U1eBdybzNx2U+X+6Lg1R/\nS9nC4J0jTsGatYGCs/uTLWlLQ7tHud9ndmkSJA2KYG4Dam9q4ARvSXSGtQNyAPWNYQhqAO1qO1Td\nvD+shiEKIoKKGbTsIdKI0WbuYiQYgKZ4CuiqigagOdYC//h3nPvawq41yZriFGm1WNXmRofMW0Wj\n8hnE4hZUxMd4AkiqYOLeYcoZQg81QJDjHdbRFrn28lWsYUnPPLp1DGKRtca8bZIzn+tklFa2LMIq\nHrP+I8qS4BS6tEfi0PbVmIFb9cFd6mq0l0JvLYNUXg+p2mxp65f8zsWa3h6C6I9g475NKJKLMCxU\n63kPIcVV+e/XneO2P2R8kmvNtpgYdTiQzOeiCeeiSC7C5AET0j4mZm0vW1bs82Qi7vXSgiDg4gnn\n4uyxp3d7HW86SoqlYnbzEZ8ieoJOQJGg6ZmHzYH0WZV7Pa6U4jF2Zt3Y0vlFdjr2h36uM9RUpG4U\nZ6VjX8wV+/v+fXXGezHoyox9nS8VS9aF2J2x8U22MHhnQTiqoqXduyuY0yXMyvTiRlJ3LtVbgOEE\nb8Xa9UtUAV1GLK5Bttbl2vPjih28NRkCBNQ1heEXzMBoN2Zpj4dRJAcgSaK5VCfuh2EIaFWb8ceP\nVgEAtP2DPB9WA4OuTRosEVenOMOqKPcrIiL2Bh3OnLcVOBUVMdnMumvh3bIwOXifNOxYs2DKZjVZ\nEcvMPa6T23C650EXTVqAG2dcA3XXqMTx6e75bxHlYiLrtT+c7WAfMMzXtocYJUl0/qO3RVRA9SH+\n3ymIb0/sya3IImCIiG01h6ztna38kh/2SgB7HXZYi2B8+egOGzKU+q3aASXmbCBiaHLK5TElUuLi\npegAMp+ja2bivuN+2GlbSltxkeL5MEquDp9dMxNzhx/f5Z/dZUnLxlQtfebt90nQNHPY3C4sTCdd\n5uS+v7Pg3dDazeCdIvPuK+5h855edOVL5q2kec+pMu/OCtbc2/Wm09MLou5i8M6C6x96Ddc/9Jrn\nPrt6Fk7w9gZ3Q9CgR4LmGlZRcTbZ8CsSSopkCJKOPfti5iYMgt061GroIgnOhhRtERXhqIZia3/h\npqg51xpWwwhamar5QWj2zd7eth0fNXyKQfIIaPW1GDwgEVA9OyxZhg8yX3dAqd8pShs62Oc0QnEy\nb6dtpwpVtIbsFe/8kXsI8VuHXdlh/99xQ8xWlfb/PfcmAYD3P+DQkhqMLR8FGCnmvGF1RLOqdodV\nlzgZROyzqYhvH4ehhrk8ys4AZVFwisbs4VmtvhbaPnP4fNSQUqdHtxEtcvrFA4Bf9jkdLY32xEhA\nqsy3LJAI3s7fhC55AtL/G30Kjhx8OAJS9pbq2PPcQyqDnkrsQTnaaMHdrEiWBOiG4azEUBTJO2yu\nmHPebRHVHJXqJCBJXVjDW2o1jXFn92NqzIu5IQO6N8ztVyQU+eU++2B3S3Vx0l2DrL+T0mJfhkf2\nrXS/d/v3ka63eTK7F7zYyd9YV9eJ97b+ffmUp+xCG8MwnA+WRPA2/9XQMXhDC0DdNQbV46LYGfkC\nsLbLnDV5IF54C04wcipXreCtSCIgxWFEfE5L1XJ/KRoANESbMArmnLddLFUe8mPP/nZz/tgXRZEc\nwHdnX463yxow3bUcZXz5mMTxqTK+MuJknHTUZLz7aT3iqo6nt5hrzysrJWyPxc1kU1NwzNQhEMsF\n/CeyCZKswfBFYBhCh0zbfXt8xRgIgoAbzp0Gnyxi9/52zJo4CDe/9ifnfert3jluWRKwaOL52NL4\neYcLjbISH5paDRiGGfyrS8px+lEjURr0YfaUwSgOKLj27EPx8z9thLprDIQae7g8kXnbRU32nuTj\nhpbhzGNGYV9TBDPGV+FnqzbA3GFcgN40EKK1I5tn2Nx1wTFj0FQkq7CaqMiDvsCGfebPOe+YyZ6i\no1NGmu1qX3rjC+e+AaW9u2/wLRdMx0dfNGDK6AGIxTV8/YQxB1Qo1VPupZR2sLH/litCfs8oix3I\nW9oT+5ink/yhe+uF0xGNe7OpMTVluPTUCc4GGwBw8hFD4fdJmDHeu/NWV10wd5wzrN/XejN4H35I\nFRbOH4+jJg/utdfMhuRs+vuLDkdTa+Iz1/130dn5GTE4hEtPnYBDhqcfteqrCzQG7yxStcSmCppm\nz3mbHxya4A3eEPREYxIkdtzy+yRnq0l7DbOdeTt7RUsCdCEOQwtiZ521UUdxJT4PA/sjDeZuZLrq\nZN5V5QEzeFsXEjXFQ1CsFOG4ad4sa3jpUAwOVmN3+15EP5yFY4+ag1DQh+Om1WD9B4lK7WDQgICw\ntbm9gLOPHY09cRn/ec8cNheUKBD3IVDi/aAt9lQrm+996hgzCE8YYfX/1v3QxXarUKxjRe+RVYfj\nyCGHdzj35SV+8z+rIQCCgepQGfyKhLlHJPp6H35IFYJ+Ge1R1cmU7Yst93CsnXkfOnoAJo9MVH+7\nq5zj28cDhoihlRVQRNmpcTDCJdAjRThhzHTPHL1znEWJC5KdrealwNETRnV4HODNEFJ1EOuJytIA\njp4yBIBZiX3aUSMyPKN3uTNve5jX3rSnqjzgyYrt77dFVFRXdF44l5xVpbsYOW5ajee2KAgd7jsQ\n44ZmnqLIld4M3pIo4sQZQzM/sI8lz0PbIympZJpKyPR30JWitmzo+zGdg5j7Cl/Tra+tYXNnj2Xz\nljlfavfztpc7WTtuOXt0W8HSZ+/HbDdOkTSn4GxHnVn1XVtqZtANkUa0W/PRRdY+t1X2jkuy+bo1\nJemvom8+4jpEP5wJI1zq3bQ+oCSGyJUwRH8EmpUZK7LobK0nyCoEXxRGzO/Zlxfo2s5DJa3mHLPe\n2DED6uxDKRS06wLMoFDiSz386QzJG4bntuyZ844797l55v00H+LbJmGYPsN6Qet+Q0R0w3E4b/xZ\nKX9+kc/nFA8CZuFdukpud/FVLqtac8Gdedvnede+xI5x7mFz9+890/RBbwaufCX1g6H7XMvlUHa6\nTaeyrfB+qzlkL7sCEsPmguBu2WmxN9+wMm87wxZE1QreiblQAAiIAc9rGFYWb2gydlrBe0SlOV+8\nP9Jo7kcNIGgFVLvNYeyzQzEiNAynjpyb9j0E5IDTKtL9HyIYkJ3g/Z+ItZtYuGPwhq/dXI8eD3To\nhdyV4F0ePgSRjXMQ+++UDt/r7EMped/idEt07Kvu5P9/dntUIJF5JweCVEU79lW49+VStwwFzGVP\n0Q+OcgJ4mb/jDlm2rhTP5Bv7nRquM2af50TmXeQ59+7fe6bCqT76XO1XCvECJpdD2U5ilmMcNu9F\numF45lI8mbfmLVhzb7cJwargtoKz0SHztrqLWXPefsneDMMM3ppgXQioMnZY2/UNG1AJn6hgQ/1m\nJxjYa4ZDRebws948ELfMPK/L7y8580bS7kh24xdFFlFkWM1gfFZns5i/Q1WwLMo4YegcDAqmn1eU\nZbFDoZqts0KT5Cvv9MHb/Dd52FwUEsHbnitLfs1Uy4DsD8p0u8h1PE4RRqzIXG5WuRdiJzsu2Mv4\netJoo78RBAGGYSRl3lbw3tcOvyIhFFSSNjFxZ96dz3l39fdwMGPwzi7nsz3HDp5PgT62e387rrjn\nn/j7267+1/HEsqrkgjVB6ph5G9awuW7vNuWLwO9zZ97mtZZTdWy9hu7KvJvaYigt9iHgk+GTzCD9\nft0mAMCospEA9OeBHQAAIABJREFUgPJQ9ypFk5tcOHtuW/RwCLIkmPv/Wpm37rOat8T9Kfv+njv+\nLBw39Oi0P7OzZRzJnbHcyoq9c8IBOXWBlz13XGQdWyITTPS/brcadSRn+ikzbyl1Jp+JfcEW19MX\nOdmvOXxQ560h84m9JMtd4S675rQHlgc6jES4f++ZMu+DbXqhOwrxHOTygqWvLqaZefeStz7aCwBY\n+ddPnPvcm444QyuiO/M2AAiJPautYFgSrwX8m6AM+wQ++WRnztvOvINyEIh3zLxrKspR7h/gVMi2\nurbpBIDRZWYR0qSRlTh99ghMH9e1Stprzz4UX9a3ej4EKkJ+nHroNLypbcWYkrF4+8P9MNpKofgT\nFfHmkjd7NzJft1oHdjZ3lep7t1wwHe9vrcfXjh8DUQT+ZT9WTP2zrz17Cl5Y/wVOn20VaLlesqqi\nCKOGhPD5LnP0IPkDIdV8qzNsbkVaWRJxwdxxad8DAFx51hSsa9iBrZFdnuHjZGcdMwqqpuOsY1IX\ntOWj75w/Df/3n+04+fBEEdSx02rQ0h6Hbhg4ekqiHuPCueMQ8MnY+mWip26m4D24MohTjhyOKaMq\nO33cwcyvSDhzzshO29MebIr8svmeh6R/zxfPH9/lTUk6M2N8FU6cXotjpw3p8WsdCAbvA7S18b+o\nC9fjqCHezQXswCYEWiGW1UPbMwLRlJm3GagF0XA2FnGG0q3g7YsOQoV/KBqKdwBSvMOcd5EcgBAX\nnNakLaq5DehXpo/F7JpEj+sLJ3wNL3z+NzRGm5znAeaQ8NeOTywDy+TwQ6pw+CEdA/3Xjp6Mq6uO\nwsdb67B+rbkft/sqNCgH0BSzh/SVbgZvMem22XMaSD1sPmFEhVOpfv5J4/Avc5dMKGLq4dXqiiAu\nPTWx/lpAYthbFARcd85UfPfnr6c8lmCKLlPJAf74aTU4cXpth8e5nXncGEzYfiZ+uWEfFhxydtrH\nBQMyFn7lkLTfz0dDBhTjklO869/H1pbh21/vuKzOXimwbXeLc1+mYXNBEHDeiWN74Ujz21ePHZ35\nQQeZTO/5pF6qmpclsU/+XzJ4H6AH3vkFAGDW4Bmebln2XHdgqtmcJdJa4Q3emrdgDYCZfetyIhu3\nhs1jcQ2KHgREQBOirszbqjZXJBTFi9BqBe9P2z+CKIiYPND7ITin5kjMqTkSb+95DwNTNFzpLe4P\nUPeSnqASRFMssZtXd7bLSw6YPlmCqplDy501TrAdWzsbr+5ch9HWlEEmiepz89+yksQUQ3Kmn7pg\nzTts3tWGVhWBciyddWPXHlzg3Bdt7o0/iAoJ//K7SdVVZ04ZAJKnWARRSxo2TypYgznsbcQDEKwm\nJPYcciSmQdB9ZvAWI4iq1lIxe523LCIoF6FNaoXgb8OeyJeYWDnes4mF2+GDDuvRe83Ep4hmP2rd\n8GTe7mpyo7uZd9J8kt8nOXPQXWn1eN74s/DVMachIHdvXbS3mYP3WFIOm9tz3vbwd+FNN2ad5ClY\n40cYFSYWrHWTmlRY1CGQGELSsLm9zjuRedubeiSGzc3gFo1rEKyGJHEjirC1TttemuWTRQSVIkCO\nQRpgNvaYOWh6z99UNwmC4HyIuueQPOuVXZttHIjkbPdAd0USBfGAAnfyum+3mKp5bqfaijIx523/\nfEbv3uYtWOvfG2QQZQuDdzfFda3zBwg6/vi3T7Hps30AXJm36FoTaFecJw2bb9vdgi92mvPcUSOM\ntri53tVu0qLIEkqUIATRgFS1A7IgY2rV5J6/qR6wP0QVxTtsbjNUxbOTU1clD5tne79cZ847xfda\n2uOe26kL1rpXbU5dx8ybiMG725Izb7ufucMKyA88/T6AdMPmKgZVFGFghbWHtWvplb0dZtQIO3tx\n25m3IotOYBT9EQwtHppoitJHjpo0CANKAzh8fLVzX1BJtAOVoXSrjaA7eI8cHMKF88b37EAzSZEo\nf+/iGZgwvByzJ3v34lZkEbMmVnsKopKHzZl4976JIyowqDKISSMrUFHau21iifIFL1u7SdW9WVgs\nufuVNY9tf3inK1i7Y9FMPP3uK3izHYAuosgvIRzVnK0129VE8LaboiiyiGI90XQk5Ov7db9nHjMK\nZyYtYXIPm/vl7q0td+/zfPslR2R9swdnnbfr1zRuaDluuXBGx8cKAq4+y+z89vQ/twBwVZsbicdQ\n7xo3tBx3XXVUXx8GUZ9i5t1NquEdNjdbV7rms63MuzJkZsTJvc0Bs1GLTxFRVGT9GgzRqdy2s+zW\nWBva4+3wiT5nWN0niyj3JdYvhtIUqvU1d8FadyrNgUTBmiSaLUaz3Xwh0XGte+Pe9pJBnfVqRJRF\nWc28V6xYgffffx+CIGDp0qWYOjWxdnPlypVYs2YNRFHElClT8P3vfz+bh9LrkofN46ruZNsAnK8H\nWMN6qea8BVmFJIoIBqzgrZutIOubIs6weVu8De1qGAEpALs1hSKLKJMSwbs0zaYbfc09593duWq7\nOMkO2tnecEBI7pd6gOSkJi3MvIkoG7KWeb/55pvYtm0bnnrqKSxfvhzLly93vtfa2opHHnkEK1eu\nxBNPPIGtW7fivffey9ahZEVyG8u4qnn7lVvBu9Rqv6m72qMaqnnNJPniePHzv6MdjQDMOe9ia39i\nSfdBgIDWeBva42FnO0/ALFgr8yeCd1mgf3ZOch9z8qYkXWVn3nZGm+3t9xLD5t2M3slLBhm7iSgL\nsvZJuG7dOsyda+5WNWbMGDQ1NaG11exzrSgKFEVBe3s7VFVFOBxGWVn6/Vb7Ql1jGI+t/djZDjJZ\nqszbvVOY0/LUCgLujUnsrFoYsAN/+XwtXtn5uvVY0Wk6IUkiipUgGqNNiGgRTxaryKI3ePeDOe9U\nZDExsNPtzFtK7K8N5K5Pc7eLxa0n9tU2gURUGLI2bF5fX4/JkxPLlyorK1FXV4eSkhL4/X5ce+21\nmDt3Lvx+P04//XSMGtV5v+aKiiBkuXeXCVVVpZ8rXrHyHWzZ3oiyUABXnNVxO8rikOJ5viCJiXXb\ngJN5y4qEqqoQJFkCYEAQAD3uAwLtHY+nrBhl1hy5IokYXl6DD+o+BQBUliSC9eDqEAJFiYA9fNAg\nVA3su3nvdOfRCNYC7wB6OIjSEn+n5zudygpzSkCWRef5AZ+E8cMruvV6mVx46kT86JE3cN68Q7r1\n+iWhAKqqQrj6nKn45aoNOHXO6C69TjbeS6HhOewdPI89l4tzmLNqc/cwZGtrKx5++GG89NJLKCkp\nwSWXXIKPPvoIEyZMSPv8hoaOwa4nqqpCqKtrSfv9fY1h69/2lI/b19CCOiVxf2tbzDNsPn54CB/s\nBMLhOOrqWtAeiSWK1TQZhi4msnPLVacfin+tM3+uKAoY5B+ED2AGb9lINKNoaQ5DiyZ+dVq72Ol7\nyabOzqMAH04oPh8vvl0HjDO6dYzhtqj1WnCe/7MbjoMgICvveVRVMX57y4kQRaFbr9/cHEZdXQtm\njhuIw7v4Opn+FikznsPewfPYc719DtNdCGRt2Ly6uhr19fXO7b1796KqytzcYuvWrRg2bBgqKyvh\n8/lwxBFHYNOmTdk6lG6xLzbSjdJ2GDbX9KRtPs3MW7NeJ2q0QvBHrBcXALXjdZMsys7wuiyJqA3V\nON9zL7tSJNFTCFWi9M9hcwCo8g0GNB/8Svf+1Ox13u4qc9GqPM+W3hqaL8StGIkoN7IWvOfMmYO1\na9cCADZv3ozq6mqUlJhBpra2Flu3bkUkYgazTZs2YeTIkdk6lG4xUqzTdY8edChYi2uA7B42N7Nq\nOxjvqFqDwNRXrRcSYcQ7NpdQRBmqtaRMEgUMK0kEb/dabrt/+IiQuctSd/t254IdfANK9wZ5ZDm3\nc91ERPkga8PmM2bMwOTJk7FgwQIIgoBly5Zh9erVCIVCmDdvHr7xjW9g0aJFkCQJ06dPxxFHHJH5\nRXPIvdRH1VX88aNVmO3aBlQ1OmbeYlFiqMQQzO87Veae1xZgtJVBLDYff+GEr+HThs9RVTQQmlYH\nwCxYqykZjONqZ0MWZcypmYUnsN45JgD47uGLu70eOVfsgjNfN1qjAole6WKWq8x7C+vUiCgXsjrn\nfdNNN3luu+e0FyxYgAULFmTzx/eIu8nGxvoP8cbut/HG7red76tJvc1jqg6xsinxfGgQBQGaYeCL\nvc3eFzdE6K3lQPUOAImtO4FEm1VZFCAKIs7vZH9nScxun+/eYGfe3a02l6zny3mSeff3iykiOjjk\nRzrTBxKZd+pGG8lz3jEtBiHYAr3NrArXoEIUBei6gR/8YZ33yboIvS310rhZE83+2ccfVpPy+/mm\nImQO6Q8o7V7v9UTm3b+D9xGHmPUcIwf3zzX3RHRwYW/zDARBgF/s2Jc7ntzbXG6EIBjQWiogBJuh\nG6qzx7Wn8xoAASLu/8ZXsOaLCMaVj/Z878hJgzC2tgyVKTZc+NkNxyY6teWJMbVluPvq2RhY1r3g\nbQ+79/fg/c2zJuO8ligGlhVlfjARUQ8xeKfhDJsLqbt6JQ+bq4K5lE2PBiHpkpN5R2Kad/03zD2m\ny0sCWDTp/JQ/e0CaQJevexdXl3c/oLl7m/dnkigycBNRznDYPI3EUjEBmqF3+H6HLUFFa9vOmB/Q\nRWiGBkkUsLehvUPmDZ2nvavsXuH9PfMmIsolRpE03FXDWlKWDXirzQ3DgC5Za7jjPhi6BNWIQxIF\nGAY6ZN727mCUmZ1550vBGhFRLjCKpJEp845riYC8e387oJidwIy4HzBEqIaayBalpODP4N1lSp7M\neRMR5RKjSBruOW/N6Dzz/vmfNkFQYgCAUn8I0BKZNwAIHQrWGIi6yqdIkCUBRT6WZxAR2Ri803A3\nadFTDZu75rwjMRWCEoUiKvjhJbNRO6AUcT0OwT67ycPmev9fn91fyJKI755/GM4/aWxfHwoRUb/B\ndCYDM/NOVbCWCOiabkDyx1DmC6G02I/yYDF2RXRIkpW+JxesaflZNd5XDhle0deHQETUrzDzTkN3\nNWlJOeftWuet6ToMKWoOmQPwS+YabdGa6xaS57w1XjMREVH3MXin47RHFVLPebuGzXUhBggGQtbu\nXgEreAv2RiVi0rC51rHpCxERUVcxeKdhrxRLV7AW01QnOzdEs9K8WDG37fRbu3wJaTNvDpsTEVH3\nMXhnIKYpWPt8dyN+9Zy5B7kmmpXmQTt4S1ZmbQftDpk3h82JiKj7GLy7INWcN0QNb31sbt9pWMG7\nWDaDtzNszsybiIiygME7A90wUg6bu7umGZJZvNZh2FxUARgQ/OGkF2XmTURE3cfgnYFupMm8reBt\nGAYMKXnY3NoRTFQhDdoGsbjZ05iFTVqIiKgnGLwzMAwj5Zy3ORRuQDcMCLKdeZu7StnD5pBUSKX7\nAQCXTlqQk+MlIqKDH4N3BuaweYrMGwAkFf/e+SaU2q0AgGDSnLchqBCKWmDEFVQFB+bkeImI6ODH\n4J2BYaRYKmavAZdUPPnpaufu5DlvTYpADIRhREKQRc5zExFR72DwzkDXOxasybDntL33FyctFYvI\n9eY3wqEO+38TERF1F4N3BobRcT9vUU/MaeuRIud+RTSXgNnD5mFxHwBAiIYwpHgQJMOH+M4xEFiv\nRkREPcDgnYFhGNCT5rwF3cysBUmFEU0Eb8GKyvawuV1ULmh++CQfZukLoe4cl/2DJiKigxqDdwYp\nC9bsJiuSCgjmBPjo8Hzn285SMYtgPV4wmHITEVHPMXin8NH+TyH4zMYqqQrWjLgVjK3gbRgCSvUa\n5/uKKENxFagJujeYExER9QSDd5KWWCt++t5v4J/2CgAz8/7vnibPY1S79kxSIQgGYAiQRG9WHfKF\nnK9FnbuIERFR72HwTtIWbwcAp6hM1XTsaWjzPMauX7MzbxgCxKQzGfKVOF+LOnuZExFR72HwThLT\nY57bcVV35rVtumadNsnsXW4Gb++pLHUFbwHmELoB7+sQERF1B4N3koga8dyOxXVAMAvWoh8dgaJw\nLbR6c37bnXlLSeu/Qkpi2Dz5e0RERD3B4J2kPSl4x1XNybz15gEI7T0aajQAABCUqBO8haQzWepP\nBG8hKXgn3yYiIjoQ7NmZJBz3bt8Zs4bNDQMABLRH44Dqg6EqEAJt1lruFAVrimvOW2SwJiKi3sPM\nO0lY6zhsLgg6YJinqj1ilprr4WIIgTAEUYNhCB0CtGepGDNtIiLqRQzeSbyZt4G4pjtD4wAQjpql\n5kakGIJgQPBFAUPskHlLouR8bX+L5WpERNQbGLyTeDJvwUAsrgGCDlmUMLqmFLo5fg4jXJx4nCFA\nTMquJw+YABgCYtsmdPgeERFRTzB4JwnHXcFbVJ05b1EQ4VcS2bQeDSYel2LYPOQrwYTGi6DtGclh\ncyIi6lUM3knCqmvYXNSdanMR3uAN3fV1ig5rAKwit8SwORERUW9g8E4Sdi0VEyTVWectQoLf5w7Y\n7lPXMfMG4AyxC4zeRETUixi8k3gzbw2abkBwhs1dp8u9Q1iKOW8gEbyd2M2KNSIi6gUM3kncTVoE\n0W5ibgZvn2vY3NATpy7VUjEAmDKyEgBw2NiBnvs5BU5ERD3BJi1JYpqrt7lkB28doiBBkdJn3qnm\nvOfOHIZDhldgWHVJh+8RERF1F4N3kqh7Y5KkzFv2BG9vIE+VeYuCgBGDQx3uJyIi6gkGbxfDMBDX\n4s7txLC5DkmQkoK3O1innvMmIiLKBs55u8R11bttp5TIvCVBhCy5h8q9WXiqYfNkrFcjIqLewODt\nYs93S4JZmCaIGiCqEATAJ/o9mbe7YC3dsHk6zNGJiKgnGLxdolbwDohW9zRRg+Azq89LlNABF6wl\nG1xpvu7omrLeOWAiIipInPN2iVvFakVSEG1aCyCp5sYjAEqVEGQxdcGakWadd7KTZtSiOCBj+riB\nGR9LRESUDoO3i5N5C2aGLEgaBMXMvEt9pZCTsm33110ZNpclEXMOHdJ7B0xERAWJw+YuMavS3O8M\nm6vOsHmZrzT9UrE07VGJiIiyIWPw3rp1ay6Oo1+IWcPmPqMIgJV5W8Pm5f4yyHLP5ryJiIh6Q8bg\n/e1vfxsXXHABVq1ahXA4nOnhec0eNldgBm9zztvMvCuKyrwFa8jc25yIiCgbMs55P//88/jkk0/w\n4osvYuHChZg4cSLOPfdcTJ06NRfHl1N2gxZBV2BoIgRJBXwRGLqIkFKMqBRJ/URD5LA5ERHlTJfm\nvMePH4/rr78eS5YswdatW7F48WJcdNFF+O9//5vlw8stO/OGIQG6DLG4GWKgHdq+wVCUpA5rbhw2\nJyKiHMqYee/cuRN/+tOf8Je//AVjx47F1VdfjWOPPRYbN27EzTffjGeeeSYXx5kT9py3oUkwNAmC\nYt6v7hwHRfL2Nvd0WwOYeRMRUc5kDN4LFy7E17/+dfzhD3/AoEGDnPunTp2aceh8xYoVeP/99yEI\nApYuXep5/K5du/Cd73wH8XgckyZNwp133tmDt9E77A5rhi4CmnlqDEOAEQtAlrztURXZvc5b5Jw3\nERHlTMZh8zVr1mDkyJFO4H7iiSfQ1tYGALj99tvTPu/NN9/Etm3b8NRTT2H58uVYvny55/t33303\nLr/8cjz77LOQJAlffvllT95Hr7CXihmqBEO39u6OKwAESJLgqTZP7rbGYXMiIsqVjMH7e9/7Hurr\n653bkUgEt9xyS8YXXrduHebOnQsAGDNmDJqamtDa2goA0HUdb7/9Nk466SQAwLJly1BTU9OtN9Cb\n7DlvXXNl3roMSTSryd0BO3nZGIfNiYgoVzIG78bGRixatMi5fdlll6G5uTnjC9fX16OiosK5XVlZ\nibq6OgDA/v37UVxcjLvuugsXXHAB7r///u4ce6+z57x1VYSzFEyTnEDtnvNOzrwZvImIKFcyznnH\n43Fs3boVY8aMAQBs2rQJ8Xg8w7M6MgzD8/WePXuwaNEi1NbW4qqrrsLLL7+ME044Ie3zKyqCkGXp\ngH9uZ6qqQp7bwqfmMUqiD7D28jZ0CQFFQlVVCEUlifddFFDgXMIYAgYOKO7weoWiUN93b+I57Dme\nw97B89hzuTiHGYP39773PSxevBgtLS3QNA2VlZW49957M75wdXW1Z7h97969qKqqAgBUVFSgpqYG\nw4cPBwDMnj0bn376aafBu6GhPePPPBBVVSHU1bV47mtpN39GuN2A4Lf28tYlSKKAuroWxOJa4sGu\nixEYApoa2+EvwOQ71XmkA8Nz2HM8h72D57HnevscprsQyDhsPm3aNKxduxbPP/881q5dixdffLFL\nmfecOXOwdu1aAMDmzZtRXV2NkpISAIAsyxg2bJizTnzz5s0YNWpUV99L1tjV5mocTuYNXXIqy93z\n3JJnqRg7rBERUe5kzLxbW1vx5z//GQ0NDQDMYfRVq1bhtdde6/R5M2bMwOTJk7FgwQIIgoBly5Zh\n9erVCIVCmDdvHpYuXYolS5bAMAyMHz/eKV7rS1E9BlmUoaqAPedtqDJ8VtB2B2jJ9bVhCDBARESU\nGxmD9w033ICamhq89tpr+MpXvoLXX38dP/jBD7r04jfddJPn9oQJE5yvR4wYgSeeeOLAjjbL4loc\nPlFBXNUhfjEdyvCPEd5+CJSqjgMU7gI1QTAYvImIKGcyDptHo1HceeedqK2txa233orHHnsML774\nYi6OLeeiWgw+yYeYqkNRy1C291hA9UNJUSgnJbdKNRi+iYgoNzIG73g8jvb2dui6joaGBpSXl2P7\n9u25OLaci2kx+CQz81ZkCbpuBmR3NzWbuymLJAuoCAVydpxERFTYMg6bn3XWWXj66adx7rnn4rTT\nTkNlZSVGjBiRi2PLuZgeQ7lYigZVQzCgQNV0AHDmvN3cwfvsY0alDPBERETZkDF42wVngLmka9++\nfZg4cWLWDyzXDMNATIvDJ/kQ13Qosoj2iAogc+bNGW8iIsqljOmiu7vaoEGDMGnSJCeYH0xUXYUB\nwwzeqg6fLELVzcxbSbEVqOgJ3kRERLmTMfOeOHEifvKTn2D69OlQFMW5f/bs2Vk9sFyLWq1RFVGB\nqhlQZBGaZs15KykK1kT3rmIM30RElDsZg/eHH34IAHjrrbec+wRBOOiCt92gRbY28VZkyZnzTpV5\ne4fN9RwcIRERkSlj8H788cdzcRx9zt4ONBG8RSd4y3IiUFeVB1DXGPEOmzPzJiKiHMoYvC+88MKU\nc9wrV67MygH1leTM2yeLUK1hc9k1RL78yqMQi2t4+p9bnftYsEZERLnUpQ5rtng8jvXr1yMYDGb1\noPqCvZe3CHN+293HXHb1MZcl0dkaVI8UQQyEUSQX5fBIiYio0GUM3rNmzfLcnjNnDq688sqsHVBf\nienmsLmExLC5rUM3Nfs5Hx+B4NCdOPb4g2v+n4iI+reMwTu5m9quXbvw+eefZ+2A+krMybzNU+Ju\nzCKLqZbGGTCixZD3TIFPUlJ8n4iIKDsyBu9LLrnE+VoQBJSUlOC6667L6kH1BSd4GzIA3ZN5y510\nTzv4VrwTEVF/lzF4/+Mf/4Cu6xCtoq14PO5Z732wiOnu4B3zbEYipxk2JyIi6gsZo9LatWuxePFi\n5/ZFF12El156KasH1RfsgjUY5ilxr+2WUgybc3UYERH1lYzB+9FHH8WPf/xj5/bvfvc7PProo1k9\nqL4Qt9Z5Q7fmvBV3wVr6wfGDsVUsERH1bxmDt2EYCIVCzu2SkpKDMmBFtCgAQLCCtzvzdq/ztjHx\nJiKivpJxznvKlCm44YYbMGvWLBiGgVdffRVTpkzJxbHllB287cxbUdzrvDnnTURE/UfG4H3bbbdh\nzZo12LBhAwRBwJlnnolTTjklF8eWU1HVCt6anXm7C9YOvpEGIiLKXxmDdzgchqIouP322wEATzzx\nBMLhMIqLi7N+cLlkZ96GZgZt91KxUNDX4fFVZQEAQG3VwXUeiIio/8s4Hnzrrbeivr7euR2JRHDL\nLbdk9aD6gp1566oZvH2yiOVXHolLTjkEIwaHOjz+lCOH44KTx+HKMybl9DiJiIgyBu/GxkYsWrTI\nuX3ZZZehubk5qwfVFyJaBIqoQNPM24osYsiAYhx/WG3KxyuyhHkzh6XMyomIiLIpY/COx+PYujWx\ng9bGjRsRj8ezelB9IaJFEZD8iKvWHt6ddFUjIiLqSxnnvL/3ve9h8eLFaGlpga7rqKiowL333puL\nY8upqBpFQPYjxuBNRET9XMYINW3aNKxduxarVq3CkiVLUF1djWuuuSYXx5ZTyZm3z9UelYiIqD/J\nmHm/9957WL16NV544QXouo4f/ehHmD9/fi6OLWd0Q0dUi8Ev+xHXmHkTEVH/ljZC/eY3v8Fpp52G\nG2+8EZWVlVi1ahWGDx+O008//aDbmMTuax6Q/IjHzYo1Bm8iIuqv0mbeDz74IMaOHYs77rgDRx11\nFICDt4931FrjHZADaGPmTURE/Vza4P3yyy/jT3/6E5YtWwZd13H22WcflFXmABCx1nj7JbNgTRBS\n7yRGRETUH6RNL6uqqnDVVVdh7dq1WLFiBb744gvs3LkTV199NV555ZVcHmPWOZm3VbDmk6WDdpSB\niIjyX5fGhmfOnIm7774br776Kk444QT8/Oc/z/Zx5VRYjQCAWbCm6hwyJyKifu2AolRJSQkWLFiA\np59+OlvH0ye8mbfG4E1ERP0aoxSA9ngYAFAkB9DcHkdx4OCqpiciooMLgzeAdtUM3qLuRzSmoao8\n0MdHRERElB6DN4D2eDsAIBoxT0dVeVFfHg4REVGnGLwBtFmZd6SNwZuIiPo/Bm8kMu+WVvM2gzcR\nEfVnDN5IzHk3NZnd1TjnTURE/RmDN4C2eDsUUUFLmxm8K0L+Pj4iIiKi9Bi8YQ6bFytBRGPmpiQ+\nhduBEhFR/8XgDXPYPCgXIRLX4FNEiGyNSkRE/VjBB2/d0BFWIwgqRYjFNfiZdRMRUT9X8ME7rEZg\nwECxHEQ6NoGoAAAYmElEQVSUwZuIiPIAg7ddad6so6E5Cr+PwZuIiPq3gg/eMc3co3zL9jYYADNv\nIiLq9wo+eMd1M3gbunkqGLyJiKi/Y/DWVfMLBm8iIsoTDN5W5g3dDNo+peBPCRER9XMFH6ni1pw3\nDPNUBFiwRkRE/RyDtzPnbWfeDN5ERNS/MXhzzpuIiPIMg7cz583gTURE+YHBW/MOmzN4ExFRf5fV\n4L1ixQqcf/75WLBgATZs2JDyMffffz8WLlyYzcPolDNsbhWsiSI3JSEiov4ta8H7zTffxLZt2/DU\nU09h+fLlWL58eYfHbNmyBf/5z3+ydQhdkrxUTNP0PjwaIiKizLIWvNetW4e5c+cCAMaMGYOmpia0\ntrZ6HnP33XfjxhtvzNYhdEksqcOapht9eThEREQZZS1419fXo6KiwrldWVmJuro65/bq1asxa9Ys\n1NbWZusQukR1qs3NzLu4SOnDoyEiIspMztUPMoxERtvY2IjVq1fj0UcfxZ49e7r0/IqKIGS5d4vJ\nqqpCkD63bugi5h85Al89aTwkznsfkKqqUF8fQt7jOew5nsPewfPYc7k4h1kL3tXV1aivr3du7927\nF1VVVQCA9evXY//+/bjooosQi8XwxRdfYMWKFVi6dGna12toaO/V46uqCqGurgXN7ebrGrqEuTNq\nsH9fa4Znkpt9Hqn7eA57juewd/A89lxvn8N0FwJZGzafM2cO1q5dCwDYvHkzqqurUVJSAgA45ZRT\n8MILL+Dpp5/Gz372M0yePLnTwJ1NqqvaXBILfuUcERHlgaxl3jNmzMDkyZOxYMECCIKAZcuWYfXq\n1QiFQpg3b162fuwBi7matMgSh8uJiKj/y+qc90033eS5PWHChA6PGTp0KB5//PFsHkannI1JdAmy\nxMybiIj6v4KPVqquWg1aBBaqERFRXij44B3T4xAMs4qdmTcREeWDgo9WcSt4CwJboxIRUX5g8NZU\nCKw0JyKiPFLwESuuxwFDYqU5ERHlDQZvPW4tEyv4U0FERHmioCOWYRiIaXFAl1lpTkREeaOgg7dq\naDBgwGCDFiIiyiMFHbxjWsz8QpcgcdiciIjyREFHLDt4G5rEYXMiIsobhR28rb7mhsaCNSIiyh8F\nHbHszFtn5k1ERHmkwIM3M28iIso/BR2xYnoi82a1ORER5YvCDt4sWCMiojxU4MHb3stb5FIxIiLK\nGwUdsRLrvGXOeRMRUd4o6IjlLBXTRQ6bExFR3ijs4O3qsMaCNSIiyhcM3gCgSdzPm4iI8kZBR6zE\nsLkERSnoU0FERHmkoCOWe9hcYcEaERHliYKOWFFnqZgEHzNvIiLKEwUdseJWhzWDmTcREeWRgo5Y\nMVfmrchS3x4MERFRFxV08I46c94iFLmgTwUREeWRgo5Yqq5CggRAgI/Bm4iI8kRBRyzVUCEKMgAw\n8yYiorxR0BErrschwpzrZvAmIqJ8UdARK66pruDNgjUiIsoPBR28VUOFYJingJk3ERHli4KOWKqu\nQrAybxasERFRvijoiBXXVQgG57yJiCi/FGzEMgzDzLw5bE5ERHmmYCOWqqvmF8y8iYgozxRsxIpr\nVvDW7cyb1eZERJQfCjd423t5W8PmLFgjIqJ8UbARy868DZ1z3kRElF8KNmLFrMwbmghBACRR6NsD\nIiIi6qKCDd6qlXnrugBFFiEIDN5ERJQfCjZ423t5G5oIRSrY00BERHmoYKOWXbCmaQJ8CivNiYgo\nfxRu8LaHzTWBmTcREeWVgo1acd0O3iIUpWBPAxER5aGCjVpxa85bUwXIYsGeBiIiykMFG7Xcw+ay\nzEpzIiLKH4UbvK2CNV0TITHzJiKiPFKwUcvpbW6IkCVm3kRElD8KN3jbvc11ETKrzYmIKI8UbNSy\nm7TAENkalYiI8krBBm9nP29dhMTMm4iI8kjBRq2Ys6uYBJmZNxER5ZGCDd5x97A5C9aIiCiPyNl8\n8RUrVuD999+HIAhYunQppk6d6nxv/fr1eOCBByCKIkaNGoXly5dDzOGSrYgaNb/QJBasERFRXsla\n1HrzzTexbds2PPXUU1i+fDmWL1/u+f4dd9yBhx56CE8++STa2trw6quvZutQUgqrEQCAocssWCMi\norySteC9bt06zJ07FwAwZswYNDU1obW11fn+6tWrMXjwYABAZWUlGhoasnUoKUXiZvCGJjPzJiKi\nvJK1qFVfX4+KigrndmVlJerq6pzbJSUlAIC9e/fi9ddfx/HHH5+tQ0kprEYhQLCqzZl5ExFR/sjq\nnLebYRgd7tu3bx+uvvpqLFu2zBPoU6moCEKWe2/f7XA8Ap/kRzsElJYEUFUV6rXXLjQ8dz3Hc9hz\nPIe9g+ex53JxDrMWvKurq1FfX+/c3rt3L6qqqpzbra2tuPLKK3HDDTfgmGOOyfh6DQ3tvXp8YTUC\nBQoAIBqNo66upVdfv1BUVYV47nqI57DneA57B89jz/X2OUx3IZC1YfM5c+Zg7dq1AIDNmzejurra\nGSoHgLvvvhuXXHIJjjvuuGwdQqci8QgU0QcALFgjIqK8krXMe8aMGZg8eTIWLFgAQRCwbNkyrF69\nGqFQCMcccwyee+45bNu2Dc8++ywA4IwzzsD555+frcPpIKxGUSGXAgAL1oiIKK9kdc77pptu8tye\nMGGC8/WmTZuy+aM7FddVqLoKQTffPoM3EVHfevnlv+OEE07u0mN/8pP7ce65C1BTU5vlo+q/CjJq\nRa0GLbv2xgBw2JyIqC/t2vUl/va3tV1+/PXXf7egAzeQw2rz/iSimcHb0MzqdS4VIyLqOw88cA8+\n/HAzHn30N9B1HV9+uRO7dn2JBx/8Be66607U1e1FOBzG5ZdfhTlzjsV1112F73znFvzzn39HW1sr\nvvhiG3bu3IFvf/u7mD17jvO6qqpi+fIfdHj+J598hPvvvweiKGDKlGm49trrU95n/5zRo8di1aqn\n0NjYiOnTD8eTT/4v2tvbcd11N+Ldd9/Gyy//HbquY/bsObj11u+ipaUFd955G9ra2lBSUoI77vgf\nXH75Rfj9759AMBjEhg3v4cknV2LFih93+5wVZPCOWsEb9rB5DtuyEhH1Z0//Ywv+89HeXn3NmROq\ncd5JY9N+/4ILFmL16qdx2WVX4pFHHoaqxvGLX/wWDQ37MWvWUTj11DOwc+cO3H77EsyZc6znuXv3\n7sF99z2E9ev/jT//eZUneLe0NKd8/oMP3oebb16KsWPH4Uc/ugO7d+9KeV86W7duwRNPrIbP58O7\n776NX/zitxBFEeeddxauvfabeOKJxzFr1myce+4CPPXUSrzzzls47rgT8dpr/8L8+afgtddewbx5\nX+nROS3I4G33NTc08+0z8yYi6j8mTpwMAAiFSvHhh5uxZs1qCIKI5uamDo+dOvUwAObyZHcXz86e\n/8UX2zB27DgAwO2335n2vnTGjh0Hn89crRQIBHDddVdBkiQ0NjaisbERn3zyEa644hoAwPnnXwQA\nqKmpxW9/+0vMn38K3n33bXzjG1cf+IlxKczgrSU2JQFYsEZEZDvvpLGdZsm5oChmD46//vUlNDc3\n4+c//y2am5txxRULOzxWkhLNu5KbgaV7fqpNsFLdJwiJxE5V1Q7Ht3v3Ljz11Er87ncrEQwGsXDh\nedZrSTAM3fNaY8eOw759+/Dhh5sxatQY+P3+zk9CBgUZtSL2piR25s2CNSKiPiOKIjRN63B/Y2Mj\nhgypgSiKeOWVfyAejx/Q66Z7/siRo7B5s7ni6a677sR///t5yvuKi4uxb5/ZbGzjxvdTvn5FRQWC\nwSA+/vgj7N69G/F4HBMnTsLbb/8HAPDcc6vw4ot/AQCcdNI8PPDAPZg375QDeh+pFGTwthlx88qH\nmTcRUd8ZMWIUPv74Izz00P2e+0844ST8+9+v4vrrr0FRURGqq6vx6KO/6fLrpnv+9dffhJ/97P/D\nNdd8A6FQKUaOHJXyvjPPPAf3338vbr75egwcWNXh9ceNG4+ioiCuueZy/P3v/4ezzjoHP/zhD3Hu\nuRdg06YNuO66q/Dvf7+G448/EQBw8snzsHfvXhx++MyenTAAgpGq6Xg/1Jvt5uJaHNf89hnojdWA\nIeKWC6ZjwojOe6tTamyn2HM8hz3Hc9g7eB57rrNz+Pzza7B79y584xvfPKDXS6Ug57wVSYHeMNi5\nzcybiIiy6Z57/gdffrkTd911X6+8XkEG72SsNiciomy69dbbevX1CjLl1HXvTAEL1oiIKJ8UZPCO\nxr1VjRw2JyKifFKQUSvWIXgz8yYiovxRkME7OfOW2B6ViIjySEFGrWjc2/mGmTcRUd96+eW/H/Bz\n3nvvHTQ07M/C0fR/hRm8Y0mZN+e8iYj6zIFuCWp7/vk1BRu8C3KpWMdhc2beRER9xb0l6PnnX4gV\nK36IlpYWaJqGG264GWPHjsP//u/v8cor/4Qoipgz51hMnDgJr776Mj7//DP8z//ci8GDzd4dfbEN\n6OWXX+VsAxqLReD3F2VlG1A3Bm+w2pyIyLZ6y1/w7t6Nvfqa06sPxTljz0j7ffeWoL///W9x5JFH\n4//9v6/i888/w09+ch8efPAXePLJ/8Vzz70ESZLw3HOrMHPmURg7djy+851bnMAN9M02oOeff6Gz\nDejixVfiZz/7VVa2AXVj8AabtBAR9RcbN25AY2MD1q59AQAQjZobSZ1wwsm44YbFmDfvFMyfn35j\nj77YBrS5uTkn24C6FWTwrgz54ZNF6IYBVTMgCgzeREQAcM7YMzrNkrNNUWTceOPNmDJlquf+m276\nHrZt+y/+8Y+/4lvf+iZ+/es/pHz+wbwNqOfYe+2V8sghwyvw1IrT8fBNJ+DXN5/Q14dDRFTQ3FuC\nTpo0Bf/618sAgM8//wxPPvm/aG1txaOP/gYjRozEZZddiVCoDO3tbSm3Ej2YtwH1nLNefbU8Iksi\nBEHgfDcRUR9zbwn69a+fj507t2Px4itwzz3/g8MOm4GSkhI0NjbgyisX4dvfvhqTJ09BaWkZDjts\nBm677VZ89tlW57X6YhvQ+++/x9kGdOHChVnbBtStILcEBbj1XW/heew5nsOe4znsHTyPPZd8Druz\nDWjy66VSkHPeRERE2dbb24C6MXgTERFlQW9vA+rGCV8iIqI8w+BNRESUZxi8iYiI8gyDNxERUZ5h\n8CYiIsozDN5ERER5hsGbiIgozzB4ExER5Zm8aY9KREREJmbeREREeYbBm4iIKM8weBMREeUZBm8i\nIqI8w+BNRESUZxi8iYiI8kxB7ue9YsUKvP/++xAEAUuXLsXUqVP7+pD6tU8++QSLFy/GpZdeiosv\nvhi7du3CLbfcAk3TUFVVhR//+Mfw+XxYs2YN/vCHP0AURZx33nk499xz+/rQ+417770Xb7/9NlRV\nxTe/+U0ceuihPIcHIBwOY8mSJdi3bx+i0SgWL16MCRMm8Bx2UyQSwRlnnIHFixdj9uzZPI8H4I03\n3sD111+PcePGAQDGjx+PK664Ivfn0Cgwb7zxhnHVVVcZhmEYW7ZsMc4777w+PqL+ra2tzbj44ouN\n2267zXj88ccNwzCMJUuWGC+88IJhGIZx//33GytXrjTa2tqM+fPnG83NzUY4HDZOP/10o6GhoS8P\nvd9Yt26dccUVVxiGYRj79+83jj/+eJ7DA/T8888bv/71rw3DMIwdO3YY8+fP5znsgQceeMA455xz\njFWrVvE8HqD169cb3/rWtzz39cU5LLhh83Xr1mHu3LkAgDFjxqCpqQmtra19fFT9l8/nw29+8xtU\nV1c7973xxhs4+eSTAQAnnngi1q1bh/fffx+HHnooQqEQAoEAZsyYgXfeeaevDrtfmTlzJn7yk58A\nAEpLSxEOh3kOD9Bpp52GK6+8EgCwa9cuDBo0iOewm7Zu3YotW7bghBNOAMD/z72hL85hwQXv+vp6\nVFRUOLcrKytRV1fXh0fUv8myjEAg4LkvHA7D5/MBAAYMGIC6ujrU19ejsrLSeQzPa4IkSQgGgwCA\nZ599FscddxzPYTctWLAAN910E5YuXcpz2E333HMPlixZ4tzmeTxwW7ZswdVXX40LLrgAr7/+ep+c\nw4Kc83Yz2B22R9KdP57Xjv72t7/h2Wefxe9+9zvMnz/fuZ/nsOuefPJJfPjhh7j55ps954fnsGue\ne+45HHbYYRg2bFjK7/M8ZjZy5Ehcd911OPXUU7F9+3YsWrQImqY538/VOSy44F1dXY36+nrn9t69\ne1FVVdWHR5R/gsEgIpEIAoEA9uzZg+rq6pTn9bDDDuvDo+xfXn31VfzqV7/Cb3/7W4RCIZ7DA7Rp\n0yYMGDAAQ4YMwcSJE6FpGoqLi3kOD9DLL7+M7du34+WXX8bu3bvh8/n4t3iABg0ahNNOOw0AMHz4\ncAwcOBAbN27M+TksuGHzOXPmYO3atQCAzZs3o7q6GiUlJX18VPnl6KOPds7h//3f/+HYY4/FtGnT\nsHHjRjQ3N6OtrQ3vvPMOjjjiiD4+0v6hpaUF9957Lx5++GGUl5cD4Dk8UG+99RZ+97vfATCnvtrb\n23kOu+HBBx/EqlWr8PTTT+Pcc8/F4sWLeR4P0Jo1a/DII48AAOrq6rBv3z6cc845OT+HBbmr2H33\n3Ye33noLgiBg2bJlmDBhQl8fUr+1adMm3HPPPdi5cydkWcagQYNw3333YcmSJYhGo6ipqcFdd90F\nRVHw0ksv4ZFHHoEgCLj44otx5pln9vXh9wtPPfUUfvrTn2LUqFHOfXfffTduu+02nsMuikQi+P73\nv49du3YhEonguuuuw5QpU3DrrbfyHHbTT3/6U9TW1uKYY47heTwAra2tuOmmm9Dc3Ix4PI7rrrsO\nEydOzPk5LMjgTURElM8KbticiIgo3zF4ExER5RkGbyIiojzD4E1ERJRnGLyJiIjyTME1aSHKN/fe\ney82btyIaDSKDz74ANOnTwcAfO1rX8NXv/rVLr3Gr3/9a4wfP97pZ53KwoUL8fvf/x6SJPXGYXvs\n2bMHn332GWbPnt3rr01UiLhUjChP7NixAxdeeCH+9a9/9fWhHLA1a9Zg69atuPHGG/v6UIgOCsy8\nifLYT3/6U+zYsQNffvklbr31VkQiEdx3333w+XyIRCJYtmwZJk+ejCVLluDwww/H7Nmzcc011+CY\nY47Bhg0b0NbWhocffhiDBg3CIYccgs2bN+OXv/wlGhsbsXv3bmzbtg1HHnkkbr/9dkSjUdx6663Y\nuXMnBg8eDEmSMGfOHM8exW1tbfjud7+L5uZmqKqKE088EWeccQYefPBBGIaB8vJyXHTRRbjzzjux\nbds2tLW14YwzzsDll1+O1atX469//SsEQcCePXswevRorFixAoqi9OEZJuqfOOdNlOd27NiBxx57\nDFOmTEFjYyN+8IMf4LHHHsOiRYvw8MMPd3j81q1bcc4552DlypWYOHEiXnzxxQ6P+eCDD/DQQw/h\n2WefxerVq9HU1IQ1a9ZAVVU888wzuOOOO/D66693eN6///1vqKqKP/7xj3jyyScRDAZRW1uLs88+\nG2eeeSYuu+wyPPbYY6iursbjjz+OZ555Bs8//zw++ugjAMDGjRv///bu2CW1MIzj+NcONQQRQi3W\nYnBsjDoSBFKNOVaEo0M4REO4HGyrKQin5ob+gDBaoiVyECEipakhWkKkQKFoiERPd5DOzYxLlysX\njvw+4+F5X97tx/PyHh7S6TSHh4eUy2VP3jKI/A/qvEU8bmJiAp/PB8DQ0BC7u7u8vb3x8vLC4OBg\nW73f78c0TQACgQBPT09tNZZlYRgGhmHg9/t5fn7m5uaG6elpAIaHh7Esq23d1NQUe3t7bGxsMDc3\nx8rKCj09rT3CxcUFDw8PXF5eAlCr1bi/v3fXf4xPnZyc5O7uzp2TLCK/KbxFPO7ztbJt22xvbzMz\nM8P5+bk7zOOzrw/Svnv28l2N4zgtQfw1lKE5y/j4+JhiscjZ2RnLy8scHR211PT19bG+vs7CwkLL\n90wmg+M4fzyXiDTp2lyki1QqFUzTpNFocHp6Sq1W69jeY2NjFItFAKrVKldXV201uVyObDaLZVnY\ntk1/fz/VahWfz0e9XgeaXf3HVb3jOOzs7Ljd//X1Na+vr7y/v1MoFBgfH+/Y+UW6iTpvkS6SSCSI\nx+MEAgFWV1exbZuDg4OO7L20tEQ2myUWizE6Oko4HG7r0IPBIKlUiv39fQzDIBKJMDIyQjgcJplM\n0tvby9raGre3t8RiMRqNBvPz8+6o1FAoxObmJqVSCdM0iUQiHTm7SLfRr2Ii8iOPj48UCgWi0SiO\n47C4uMjW1pb73/m/ymQy5PN50ul0R/YT6WbqvEXkRwYGBjg5OXHnE8/OznYsuEXk76jzFhER8Rg9\nWBMREfEYhbeIiIjHKLxFREQ8RuEtIiLiMQpvERERj1F4i4iIeMwvRph4T/csGFUAAAAASUVORK5C\nYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "metadata": { + "id": "HNqUFL4deCsL", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 4. Case study: building an RNN\n" + ] + }, + { + "metadata": { + "id": "YkC1k4HEQ7rw", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode." + ] + }, + { + "metadata": { + "id": "7nkPDl5CTCNb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant." + ] + }, + { + "metadata": { + "id": "A0uREmVXCQEw", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def parse(line):\n", + " \"\"\"Parses a line from the colors dataset.\n", + " \n", + " Args:\n", + " line: A comma-separated string containing four items:\n", + " color_name, red, green, and blue, representing the name and\n", + " respectively the RGB value of the color, as an integer\n", + " between 0 and 255.\n", + "\n", + " Returns:\n", + " A tuple of three tensors (rgb, chars, length), of shapes: (batch_size, 3),\n", + " (batch_size, max_sequence_length, 256) and respectively (batch_size).\n", + " \"\"\"\n", + " items = tf.string_split([line], \",\").values\n", + " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", + " color_name = items[0]\n", + " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", + " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", + " return rgb, chars, length\n", + "\n", + "\n", + "def maybe_download(filename, work_directory, source_url):\n", + " \"\"\"Downloads the data from source url.\"\"\"\n", + " if not tf.gfile.Exists(work_directory):\n", + " tf.gfile.MakeDirs(work_directory)\n", + " filepath = os.path.join(work_directory, filename)\n", + " if not tf.gfile.Exists(filepath):\n", + " temp_file_name, _ = six.moves.urllib.request.urlretrieve(source_url)\n", + " tf.gfile.Copy(temp_file_name, filepath)\n", + " with tf.gfile.GFile(filepath) as f:\n", + " size = f.size()\n", + " print('Successfully downloaded', filename, size, 'bytes.')\n", + " return filepath\n", + "\n", + "\n", + "def load_dataset(data_dir, url, batch_size, training=True):\n", + " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", + " path = maybe_download(os.path.basename(url), data_dir, url)\n", + " dataset = tf.data.TextLineDataset(path)\n", + " dataset = dataset.skip(1)\n", + " dataset = dataset.map(parse)\n", + " dataset = dataset.cache()\n", + " dataset = dataset.repeat()\n", + " if training:\n", + " dataset = dataset.shuffle(buffer_size=3000)\n", + " dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))\n", + " return dataset\n", + "\n", + "\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "data_dir = \"tmp/rnn/data\"" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "waZ89t3DTUla", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Next, we set up the RNNColobot model, which is very similar to the one we used in the main exercise.\n", + "\n", + "Autograph doesn't fully support classes yet (but it will soon!), so we'll write the model using simple functions." + ] + }, + { + "metadata": { + "id": "9v8AJouiC44V", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def model_components():\n", + " lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", + " lower_cell.build(tf.TensorShape((None, 256)))\n", + " upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " upper_cell.build(tf.TensorShape((None, 256)))\n", + " relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", + " relu_layer.build(tf.TensorShape((None, 128)))\n", + " return lower_cell, upper_cell, relu_layer\n", + "\n", + "\n", + "def rnn_layer(chars, cell, batch_size, training):\n", + " \"\"\"A simple RNN layer.\n", + " \n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " n = tf.shape(chars)[0]\n", + " i = 0\n", + " while i < n:\n", + " ch = chars[i]\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " i += 1\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + "\n", + "def model(inputs, lower_cell, upper_cell, relu_layer, batch_size, training):\n", + " \"\"\"RNNColorbot model.\n", + " \n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", + " \n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " lower_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " upper_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " relu_layer: An object of type tf.layers.Dense\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + " \n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " (chars, length) = inputs\n", + " chars_time_major = tf.transpose(chars, [1, 0, 2])\n", + " chars_time_major.set_shape((None, batch_size, 256))\n", + "\n", + " hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)\n", + " final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)\n", + "\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(final_outputs, indices)\n", + " return relu_layer(sequence_ends)\n", + "\n", + "def loss_fn(labels, predictions):\n", + " return tf.reduce_mean((predictions - labels) ** 2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "JjK4gXFvFsf4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch." + ] + }, + { + "metadata": { + "id": "ZWQMExk0S6X6", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " iterator = train_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)\n", + " loss = loss_fn(labels, predictions)\n", + " optimizer.minimize(loss)\n", + " if step % (num_steps // 10) == 0:\n", + " print('Step', step, 'train loss', loss)\n", + " step += 1\n", + " return step\n", + "\n", + "\n", + "def test(eval_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " total_loss = 0.0\n", + " iterator = eval_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)\n", + " total_loss += loss_fn(labels, predictions)\n", + " step += 1\n", + " print('Test loss', total_loss)\n", + " return total_loss\n", + "\n", + "\n", + "def train_model(train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps):\n", + " optimizer = tf.train.AdamOptimizer(learning_rate=0.01)\n", + "\n", + " train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps=tf.constant(train_steps))\n", + " test(eval_data, lower_cell, upper_cell, relu_layer, 50, num_steps=tf.constant(2))\n", + "\n", + " print('Colorbot is ready to generate colors!\\n\\n')\n", + " \n", + " # In graph mode, every op needs to be a dependent of another op.\n", + " # Here, we create a no_op that will drive the execution of all other code in\n", + " # this function. Autograph will add the necessary control dependencies.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "iopcs5hXG2od", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we add code to run inference on a single input, which we'll read from the input.\n", + "\n", + "Note the `do_not_convert` annotation that lets us disable conversion for certain functions and run them as a `py_func` instead, so you can still call them from compiled code." + ] + }, + { + "metadata": { + "id": "DyU0wnnAFEYj", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)\n", + "def draw_prediction(color_name, pred):\n", + " pred = pred * 255\n", + " pred = pred.astype(np.uint8)\n", + " plt.axis('off')\n", + " plt.imshow(pred)\n", + " plt.title(color_name)\n", + " plt.show()\n", + "\n", + "\n", + "def inference(color_name, lower_cell, upper_cell, relu_layer):\n", + " _, chars, sequence_length = parse(color_name)\n", + " chars = tf.expand_dims(chars, 0)\n", + " sequence_length = tf.expand_dims(sequence_length, 0)\n", + " pred = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, 1, training=False)\n", + " pred = tf.minimum(pred, 1.0)\n", + " pred = tf.expand_dims(pred, 0)\n", + " draw_prediction(color_name, pred)\n", + " # Create an op that will drive the entire function.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Nt0Kv5OCHip0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we put everything together.\n", + "\n", + "Note that the entire training and testing code is all compiled into a single op (`tf_train_model`) that you only execute once! We also still use a `sess.run` loop for the inference part, because that requires keyboard input." + ] + }, + { + "metadata": { + "id": "-GmWa0GtYWdh", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 668 + }, + "outputId": "61f4af1d-c81e-44db-9079-1a7b8ed8ce58", + "executionInfo": { + "status": "ok", + "timestamp": 1522345877153, + "user_tz": 240, + "elapsed": 75500, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", + " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", + "\n", + " tb = widgets.TabBar([\"RNN Colorbot\"])\n", + " while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " sess.run(inference_ops, {color_name_placeholder: color_name})\n", + " plt.show()\n", + "\n", + "with tf.Graph().as_default():\n", + " # Read the data.\n", + " batch_size = 64\n", + " train_data = load_dataset(data_dir, train_url, batch_size)\n", + " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", + " \n", + " # Create the model components.\n", + " lower_cell, upper_cell, relu_layer = model_components()\n", + " # Create the helper placeholder for inference.\n", + " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", + " \n", + " # Compile the train / test code.\n", + " tf_train_model = autograph.to_graph(train_model)\n", + " train_model_ops = tf_train_model(\n", + " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", + " \n", + " # Compile the inference code.\n", + " tf_inference = autograph.to_graph(inference)\n", + " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", + " \n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " \n", + " # Run training and testing.\n", + " sess.run(train_model_ops)\n", + " \n", + " # Run the inference loop.\n", + " run_input_loop(sess, inference_ops, color_name_placeholder)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('Successfully downloaded', 'train.csv', 28010L, 'bytes.')\n", + "('Successfully downloaded', 'test.csv', 2414L, 'bytes.')\n", + "Step 0 train loss 0.37890616\n", + "Step 10 train loss 0.18515904\n", + "Step 20 train loss 0.0892782\n", + "Step 30 train loss 0.07883155\n", + "Step 40 train loss 0.08585831\n", + "Step 50 train loss 0.09302989\n", + "Step 60 train loss 0.089012615\n", + "Step 70 train loss 0.07275697\n", + "Step 80 train loss 0.06644974\n", + "Step 90 train loss 0.0854013\n", + "Test loss 0.13216865Colorbot is ready to generate colors!\n", + "\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b102d936-3379-11e8-ac70-0242ac110002\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"borderColor\": [\"#a7a7a7\"], \"tabNames\": [\"RNN Colorbot\"], \"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"elementId\": \"id1\"});\n", + "//# sourceURL=js_e223a56194" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b103532a-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_b8c6a821fb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b105b28c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_44805e254b" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b106197a-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_a63d3c6c47" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1069f44-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b106197a-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7e203b8bce" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1070f38-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_d53293d4a7" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6d90d5c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b105b28c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_3000dc2c05" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6da872c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_4136f669a3" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dac868-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_2f70dd9aee" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6db07d8-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6dac868-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7226726048" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dcc6fe-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_72e7709865" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAFZCAYAAADHDNdrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9JJREFUeJzt3E1Lle0ax+HTF4jeEAyMBhE0DawI\nwsCH0AIlaGBWNJBo0CDoA0TQhmDXuKAGDioiCA2KlEAlnl05FD9Co8BeaGCQoBDa2jPZsXt4Bvu/\n0+o4Rmvd1zW4rsmP84bFamo0Go0C4H/WvNYHAPhVCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKDy\nUxgeHq5Dhw7V4OBgPXz4sHp7e+vWrVt15cqVOnnyZN2/f78ajUbdvn27+vr6qqenp65du1YrKytV\nVfXhw4e6cOFC9fX1VV9fX01PT1dV1dzcXHV3d9eDBw/q+PHj9ccff9TExMRaXpWfWOtaHwD+zuvX\nr+vOnTs1MTFRbW1tdf78+dW16enpGh8fr/b29hobG6upqal6/Phxbdy4sS5evFgjIyM1NDRUly5d\nqv3799fw8HC9efOmTp8+XVNTU1VV9enTp2pubq5nz57V5ORk3bhxo44dO7ZW1+UnZkJl3Zudna2D\nBw9WR0dHbdiwoQYHB1fX9u7dW+3t7VVV9fLlyxocHKytW7dWa2trnTp1qp4/f16Li4s1MzNT586d\nq6qqXbt21YEDB1an1OXl5Tpx4kRVVe3Zs6fevXv3Yy/IL8OEyrr3+fPnamtrW/2+ffv21c//+Xxh\nYaHu3r1bjx49qqqqlZWVam9vr4WFhWo0GnXmzJnVvYuLi9XV1VVVVS0tLbVp06aqqmpubq6vX7/+\nX+/Dr0tQWfe2bNlSi4uLq98/fvz43X0dHR3V29tbQ0ND3zxfXl6ulpaWevLkSW3evPmbtbm5ufyB\n+W155Wfd6+zsrJmZmZqfn68vX77U2NjYd/cdOXKkxsfHa2lpqaqqRkdH6+nTp9Xa2lqHDx+u0dHR\nqqpaWlqqy5cv1/v373/YHfg9CCrrXmdnZw0MDNTAwECdPXu2enp6vrvv6NGj1dPTUwMDA9Xf318v\nXryo7u7uqqq6evVqzc7OVn9/fw0MDNTOnTtrx44dP/Ia/Aaa/B8qP4NGo1FNTU1VVfXq1au6efPm\nX06qsFZMqKx78/Pz1dXVVW/fvq1Go1GTk5O1b9++tT4W/BcTKj+FkZGRunfvXjU1NdXu3bvr+vXr\ntW3btrU+FnxDUAFCvPIDhAgqQMi6+WH/kX8eXesjAPytf/3jz79cM6EChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkBI\nU6PRaKz1IQB+BSZUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBB\nBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAkH8D1Aj8lNhhe7QAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70592aa-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6da872c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_25c3aaf79a" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70842c0-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_984c56b816" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c708dec4-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_e0451a1217" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7092726-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c708dec4-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7aa23d7385" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7099044-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_5722756ddb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "stream", + "text": [ + "Give me a color name (or press 'enter' to exit): \n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7baac12-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c70842c0-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_cdd622e58f" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + } + ] + }, + { + "metadata": { + "id": "AHJ2c47U-A5W", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Where do we go next?\n", + "\n", + "Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings — write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon -- stay tuned!" + ] + } + ] +} diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD similarity index 74% rename from tensorflow/contrib/py2tf/impl/BUILD rename to tensorflow/contrib/autograph/impl/BUILD index 90ffabbc9bf4524ec2ebf54b6dd847bd8768a486..54424e26472b8466b8fe68ea848b5463c10224c9 100644 --- a/tensorflow/contrib/py2tf/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -25,10 +25,11 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -38,10 +39,12 @@ py_test( name = "api_test", srcs = ["api_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", + "//third_party/py/numpy", ], ) @@ -49,6 +52,7 @@ py_test( name = "conversion_test", srcs = ["conversion_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/py2tf/impl/api.py b/tensorflow/contrib/autograph/impl/api.py similarity index 76% rename from tensorflow/contrib/py2tf/impl/api.py rename to tensorflow/contrib/autograph/impl/api.py index 883b304089024363f41cabde2cb74c49f01ae836..dce994e50df60d8bd419f62207d77035beac9f5a 100644 --- a/tensorflow/contrib/py2tf/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -20,15 +20,20 @@ from __future__ import print_function from functools import wraps +from enum import Enum + +# pylint:disable=g-bad-import-order import gast import six - -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.impl import conversion -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import inspect_utils -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.utils import builtins +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -37,55 +42,6 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def graph_ready(f): - """No-op decorator that explicitly marks a function as graph-ready. - - Graph-ready functions are assumed to not need any conversion. - - Args: - f: Any callable. - Returns: - f itself. - """ - setattr(f, '__pyct_is_compile_decorator', True) - return f - - -def convert_inline(f, *args, **kwargs): - """Shorthand to convert and call a function. - - For example, the following two statements are equivalent: - - @convert() - def foo(): - ... - foo(bar) - - def foo(): - ... - convert_inline(foo, bar) - - Args: - f: Function to convert. Only this call will be converted. - *args: Passed through to f. - **kwargs: Passed through to f, with the following exceptions: - * arg_value_hints: A dict mapping parameter names to objects that can - hint at the type of those parameters. - - Returns: - The result of the converted f applied to args and kwargs. - """ - if 'arg_value_hints' in kwargs: - arg_value_hints = kwargs['arg_value_hints'] - del kwargs['arg_value_hints'] - else: - arg_value_hints = None - if tf_inspect.ismethod(f): - # When converting methods, the result is still an unbound function. - args = (f.__self__,) + args - return convert(arg_value_hints)(f)(*args, **kwargs) - - def convert(recursive=False, verbose=False, arg_types=None): """Decorator that compiles a function to graph mode. @@ -122,6 +78,55 @@ def convert(recursive=False, verbose=False, arg_types=None): return decorator +class RunMode(Enum): + GRAPH = 1 + PY_FUNC = 2 + + +def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): + """Decorator that suppresses compilation of a function. + + Args: + run_as: RunMode value. Whether to run the function as-is, or wrap it into + a py_func. + return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or + empty list or tuple will create a dummy return value that can be used + to set control dependencies. + + Returns: + A decorator that wraps the original function. + """ + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def graph_wrapper(*args, **kwargs): + return f(*args, **kwargs) + + @wraps(f) + def py_func_wrapper(*args, **kwargs): + if kwargs: + raise NotImplementedError( + 'RunMode.PY_FUNC does not yet support kwargs') + # TODO(mdan): Add support for kwargs. + return py_func.wrap_py_func( + f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) + + if run_as == RunMode.GRAPH: + wrapper = graph_wrapper + elif run_as == RunMode.PY_FUNC: + wrapper = py_func_wrapper + else: + raise ValueError('unknown value for run_as: %s' % run_as) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): """Compiles a function call inline.""" # TODO(mdan): This needs cleanup. @@ -227,7 +232,7 @@ def to_graph(e, """ conversion_map = conversion.ConversionMap( recursive=recursive, - nocompile_decorators=(convert, graph_ready, convert_inline), + nocompile_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, api_module=tf_inspect.getmodule(to_graph)) _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) @@ -242,7 +247,10 @@ def to_graph(e, # The compiled code should see everything the entry function saw. # TODO(mdan): This might not work well if the call tree spans modules? if tf_inspect.isfunction(e): - compiled_node.__dict__.update(inspect_utils.getnamespace(e)) + for key, val in inspect_utils.getnamespace(e).items(): + # Avoid overwriting entities that have been transformed. + if key not in compiled_node.__dict__: + compiled_node.__dict__[key] = val compiled_fn = getattr(compiled_node, name) if verbose: @@ -274,7 +282,7 @@ def to_code(e, """ conversion_map = conversion.ConversionMap( recursive=recursive, - nocompile_decorators=(convert, graph_ready, convert_inline), + nocompile_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, api_module=tf_inspect.getmodule(to_graph)) conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) diff --git a/tensorflow/contrib/py2tf/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py similarity index 81% rename from tensorflow/contrib/py2tf/impl/api_test.py rename to tensorflow/contrib/autograph/impl/api_test.py index 13f8e66018920a5b13f8bd3f00c67d3bbdd519aa..ee2d301d7562ef5ba6bc7ca6d013b99dec78d4c3 100644 --- a/tensorflow/contrib/py2tf/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.impl import api -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.pyct import parser +import numpy as np + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -34,10 +37,8 @@ class ApiTest(test.TestCase): def setUp(self): config.COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils', - 'tf = py2tf_utils.fake_tf()' - ) + 'from tensorflow.contrib.autograph import utils as ' + 'autograph_utils', 'tf = autograph_utils.fake_tf()') def test_decorator_recurses(self): @@ -81,11 +82,11 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_converted(self): + def test_decorator_calls_unconverted_graph(self): class TestClass(object): - @api.graph_ready + @api.do_not_convert(api.RunMode.GRAPH) def called_member(self, a): return tf.negative(a) @@ -102,20 +103,23 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_decorated(self): + def test_decorator_calls_unconverted_py_func(self): class TestClass(object): - @api.convert() + @api.do_not_convert( + api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1)) def called_member(self, a): - if a < 0: - a = -a - return a + return np.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= self.called_member(a) + y = self.called_member(a) + # set_shape works around while_loop's limitations. + # TODO(mdan): Allow specifying shapes (or ShapeLike) instead. + y.set_shape(a.shape) + x //= y return x tc = TestClass() @@ -125,10 +129,11 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_convert_call_site_decorator(self): + def test_decorator_calls_decorated(self): class TestClass(object): + @api.convert() def called_member(self, a): if a < 0: a = -a @@ -137,7 +142,7 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.convert_inline(self.called_member, a) + x //= self.called_member(a) return x tc = TestClass() @@ -147,17 +152,20 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_graph_ready_call_site_decorator(self): + def test_convert_call_site_decorator(self): class TestClass(object): def called_member(self, a): - return tf.negative(a) + if a < 0: + a = -a + return a @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.graph_ready(self.called_member(a)) + x //= api.converted_call(self.called_member, False, False, {}, self, + a) return x tc = TestClass() @@ -168,6 +176,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_to_graph_basic(self): + def test_fn(x, s): while tf.reduce_sum(x) > s: x //= 2 @@ -180,6 +189,7 @@ class ApiTest(test.TestCase): self.assertListEqual([1, 2], sess.run(x).tolist()) def test_to_code_basic(self): + def test_fn(x, s): while tf.reduce_sum(x) > s: x /= 2 @@ -188,7 +198,7 @@ class ApiTest(test.TestCase): compiled_code = api.to_code(test_fn) # Just check for some key words and that it is parseable Python code. - self.assertRegexpMatches(compiled_code, 'py2tf_utils\\.run_while') + self.assertRegexpMatches(compiled_code, 'autograph_utils\\.run_while') self.assertIsNotNone(parser.parse_str(compiled_code)) diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/autograph/impl/config.py similarity index 73% rename from tensorflow/contrib/py2tf/impl/config.py rename to tensorflow/contrib/autograph/impl/config.py index bdbc6663dd65ed66c55ad2d2e52428084bbea219..26326465e265f5b40c3badedc0ea2813248ef60f 100644 --- a/tensorflow/contrib/py2tf/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.autograph import utils PYTHON_LITERALS = { @@ -35,16 +35,21 @@ DEFAULT_UNCOMPILED_MODULES = set(( # All of tensorflow's subpackages. Unlike the root tf module, they don't # have well-known names. Not refering to the module directly to avoid # circular imports. - (utils.__name__[:-len('.contrib.py2tf.utils')],), + ( + utils.__name__[:-len('.contrib.autograph.utils')],), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) -# TODO(mdan): Also allow controlling the generated names (for testability). +# TODO(mdan): Also allow controlling the generated names. +# TODO(mdan); Consolidate all internal imports into a single __ag module. COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', 'import tensorflow as tf', - 'from tensorflow.contrib.py2tf.impl import api as ' - 'py2tf_api', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') + 'from tensorflow.contrib.autograph.impl import api' + ' as autograph_api', + 'from tensorflow.contrib.autograph import utils' + ' as autograph_utils', + 'from tensorflow.contrib.autograph import operators' + ' as __ops', +) diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py similarity index 84% rename from tensorflow/contrib/py2tf/impl/conversion.py rename to tensorflow/contrib/autograph/impl/conversion.py index 37b24ab55fdd1b03e12e9afe06530e3c26218b61..62a49cd92d835fb942f48354041cb0ab03d02c97 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -20,31 +20,31 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.converters import for_loops -from tensorflow.contrib.py2tf.converters import ifexp -from tensorflow.contrib.py2tf.converters import lists -from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.converters import name_scopes -from tensorflow.contrib.py2tf.converters import side_effect_guards -from tensorflow.contrib.py2tf.converters import single_return -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.impl import naming -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import inspect_utils -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info -from tensorflow.contrib.py2tf.utils import type_hints +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.converters import for_loops +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info +from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect @@ -213,19 +213,19 @@ def class_to_graph(c, conversion_map): def _add_self_references(namespace, api_module): """Self refs are only required for analysis and are not used directly.""" # Manually add the utils namespace which may be used from generated code. - if 'py2tf_util' not in namespace: - namespace['py2tf_utils'] = utils - elif namespace['py2tf_utils'] != utils: + if 'autograph_util' not in namespace: + namespace['autograph_utils'] = utils + elif namespace['autograph_utils'] != utils: raise ValueError( - 'The module name "py2tf_utils" is reserved and may not be used.') + 'The module name "autograph_utils" is reserved and may not be used.') # We also make reference to the api module for dynamic conversion, but # to avoid circular references we don't import it here. - if 'py2tf_api' not in namespace: - namespace['py2tf_api'] = api_module - elif namespace['py2tf_api'] != api_module: + if 'autograph_api' not in namespace: + namespace['autograph_api'] = api_module + elif namespace['autograph_api'] != api_module: raise ValueError( - 'The module name "py2tf_api" is reserved and may not be used.') + 'The module name "autograph_api" is reserved and may not be used.') def function_to_graph(f, conversion_map, arg_values, arg_types, diff --git a/tensorflow/contrib/py2tf/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py similarity index 96% rename from tensorflow/contrib/py2tf/impl/conversion_test.py rename to tensorflow/contrib/autograph/impl/conversion_test.py index 9ff256aace7a0e7ac5e7ac07e580b8bed7d8df6f..7066739eb87f89ab98e906b10dab62baeaa2de8e 100644 --- a/tensorflow/contrib/py2tf/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.impl import conversion +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/impl/naming.py b/tensorflow/contrib/autograph/impl/naming.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming.py rename to tensorflow/contrib/autograph/impl/naming.py index 51326091de13715c32d0a79279f1d3274e48ad10..1facaa0ca0ebcc6d4281e7c92a462ceeb00b453a 100644 --- a/tensorflow/contrib/py2tf/impl/naming.py +++ b/tensorflow/contrib/autograph/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import qual_names class Namer(object): diff --git a/tensorflow/contrib/py2tf/impl/naming_test.py b/tensorflow/contrib/autograph/impl/naming_test.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming_test.py rename to tensorflow/contrib/autograph/impl/naming_test.py index beb4e54937bbb91b19157c9b9e3c528353206c62..73fc0894655cb49e4f61bf8ca51995b06feb3072 100644 --- a/tensorflow/contrib/py2tf/impl/naming_test.py +++ b/tensorflow/contrib/autograph/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import naming +from tensorflow.contrib.autograph.impl import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7856c253bd0c83b1712267184393a8742576bfcd --- /dev/null +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "operators", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [], +) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/autograph/operators/__init__.py similarity index 62% rename from tensorflow/contrib/bayesflow/python/ops/hmc.py rename to tensorflow/contrib/autograph/operators/__init__.py index c8a5a195d3d709ded7afd09287255deab2ac2f3c..c3f4cab69eed416ed5f4987076969de9c353c203 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" +"""This module implements operators that we overload. + +Note that "operator" is used loosely here, and includes control structures like +conditionals and loops, implemented in functional form, using for example +closures for the body. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -# go/tf-wildcard-import -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member -from tensorflow.python.util import all_util - -_allowed_symbols = [ - "sample_chain", - "kernel", -] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD similarity index 98% rename from tensorflow/contrib/py2tf/pyct/BUILD rename to tensorflow/contrib/autograph/pyct/BUILD index edec5f7712d08247437c9e95d743e59dafffcd7b..c483ff68c4b7c6d9a3315f569b62b8f253079f00 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -66,6 +66,7 @@ py_test( name = "compiler_test", srcs = ["compiler_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/py2tf/pyct/__init__.py b/tensorflow/contrib/autograph/pyct/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/__init__.py rename to tensorflow/contrib/autograph/pyct/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/anno.py rename to tensorflow/contrib/autograph/pyct/anno.py diff --git a/tensorflow/contrib/py2tf/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py similarity index 97% rename from tensorflow/contrib/py2tf/pyct/anno_test.py rename to tensorflow/contrib/autograph/pyct/anno_test.py index 6c29918fdfaaa0224f20a2c3cb2ea8088f3eb52b..1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d 100644 --- a/tensorflow/contrib/py2tf/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py similarity index 87% rename from tensorflow/contrib/py2tf/pyct/ast_util.py rename to tensorflow/contrib/autograph/pyct/ast_util.py index f916775b9cf3cec960ec2896c334f1d737862205..4f76a695228f7d84b80b2e4b03801e15e94b8f11 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -22,7 +22,7 @@ import ast import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno class CleanCopier(gast.NodeVisitor): @@ -84,7 +84,10 @@ class SymbolRenamer(gast.NodeTransformer): return self._process(node) def visit_Attribute(self, node): - return self._process(node) + if anno.hasanno(node, anno.Basic.QN): + return self._process(node) + # Attributes of dynamic objects will not have a QN. + return self.generic_visit(node) def rename_symbols(node, name_map): @@ -94,3 +97,12 @@ def rename_symbols(node, name_map): elif isinstance(node, tuple): return tuple(renamer.visit(n) for n in node) return renamer.visit(node) + + +def keywords_to_dict(keywords): + keys = [] + values = [] + for kw in keywords: + keys.append(gast.Str(kw.arg)) + values.append(kw.value) + return gast.Dict(keys=keys, values=values) diff --git a/tensorflow/contrib/py2tf/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py similarity index 78% rename from tensorflow/contrib/py2tf/pyct/ast_util_test.py rename to tensorflow/contrib/autograph/pyct/ast_util_test.py index a871ccad6fc7ea1487e41fd6da3ce6120bdcbcbd..8faf92c705d997db298dbb1115981fd9da26372d 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.python.platform import test @@ -74,6 +76,17 @@ class AstUtilTest(test.TestCase): self.assertFalse(ret is new_node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_keywords_to_dict(self): + keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords + d = ast_util.keywords_to_dict(keywords) + # Make sure we generate a usable dict node by attaching it to a variable and + # compiling everything. + output = parser.parse_str('b = 3') + output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) + result, _ = compiler.ast_to_object(output) + self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + print(d) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py similarity index 98% rename from tensorflow/contrib/py2tf/pyct/compiler.py rename to tensorflow/contrib/autograph/pyct/compiler.py index 507dbc7ed3de9c0b8874164e97a3d1d149e42423..24c4517afa89147101f80af3ef60237132c1144c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -31,7 +31,7 @@ import astor import gast -def ast_to_source(node, indentation): +def ast_to_source(node, indentation=' '): """Return the source code of given AST.""" if isinstance(node, gast.AST): node = gast.gast_to_ast(node) diff --git a/tensorflow/contrib/py2tf/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/compiler_test.py rename to tensorflow/contrib/autograph/pyct/compiler_test.py index 243f4c81538f5853a01ff444f2ff16ccf7cd5d62..98cdc1506b6aced603df99662f1468687a55f92c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -22,8 +22,8 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test from tensorflow.python.util import tf_inspect diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/context.py rename to tensorflow/contrib/autograph/pyct/context.py diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/inspect_utils.py rename to tensorflow/contrib/autograph/pyct/inspect_utils.py diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py similarity index 98% rename from tensorflow/contrib/py2tf/pyct/inspect_utils_test.py rename to tensorflow/contrib/autograph/pyct/inspect_utils_test.py index 5528ac851f74bd7b7dacdbe7b930945afa8c9783..ddca6f963b8abadd621c544a79935c69326bf65e 100644 --- a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -22,7 +22,7 @@ from functools import wraps import six -from tensorflow.contrib.py2tf.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py similarity index 64% rename from tensorflow/contrib/py2tf/pyct/parser.py rename to tensorflow/contrib/autograph/pyct/parser.py index dc7df883b349becd860bb0dbceab22cb39c750b5..c961efa892df6a21804dae8f52ef64bf99cd409e 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -29,12 +29,30 @@ from tensorflow.python.util import tf_inspect def parse_entity(entity): - """Return the AST of given entity.""" + """Returns the AST of given entity.""" source = tf_inspect.getsource(entity) source = textwrap.dedent(source) return parse_str(source), source def parse_str(src): - """Return the AST of given piece of code.""" + """Returns the AST of given piece of code.""" return gast.parse(src) + + +def parse_expression(src): + """Returns the AST of given identifier. + + Args: + src: A piece of code that represents a single Python expression + Returns: + A gast.AST object. + Raises: + ValueError: if src does not consist of a single Expression. + """ + node = parse_str(src) + assert isinstance(node, gast.Module) + if len(node.body) != 1 and not isinstance(node.body[0], gast.Expr): + raise ValueError( + 'Expected a single expression, found instead %s' % node.body) + return node.body[0].value diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/autograph/pyct/parser_test.py similarity index 80% rename from tensorflow/contrib/py2tf/pyct/parser_test.py rename to tensorflow/contrib/autograph/pyct/parser_test.py index f35dfa04c70dc191078248c32f9a04d28133129a..007a4c6fb0393b7235808478d55b3ffa469f85d0 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/autograph/pyct/parser_test.py @@ -20,28 +20,33 @@ from __future__ import print_function import textwrap -from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test -def f(x): - return x + 1 - - class ParserTest(test.TestCase): def test_parse_entity(self): + + def f(x): + return x + 1 + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): mod = parser.parse_str( textwrap.dedent(""" - def f(x): - return x + 1 + def f(x): + return x + 1 """)) self.assertEqual('f', mod.body[0].name) + def test_parse_expression(self): + node = parser.parse_expression('a.b') + self.assertEqual('a', node.value.id) + self.assertEqual('b', node.attr) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/autograph/pyct/pretty_printer.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/pretty_printer.py rename to tensorflow/contrib/autograph/pyct/pretty_printer.py diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/pretty_printer_test.py rename to tensorflow/contrib/autograph/pyct/pretty_printer_test.py index 81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99..0cb48f35760b7b2655eb5cf73017b70e28dae219 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py similarity index 91% rename from tensorflow/contrib/py2tf/pyct/qual_names.py rename to tensorflow/contrib/autograph/pyct/qual_names.py index 6bcbaeb2aeb3043919e84bc6599edf5aee583c6d..4d5764a974aac542ddf4a54a9acd36f1afcb0464 100644 --- a/tensorflow/contrib/py2tf/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -29,7 +29,7 @@ import collections import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno class Symbol(collections.namedtuple('Symbol', ['name'])): @@ -169,14 +169,6 @@ class QnResolver(gast.NodeTransformer): Note: Not using NodeAnnos to avoid circular dependencies. """ - def visit_Call(self, node): - node = self.generic_visit(node) - # This helps treat the following cases uniformly: - # a = b[i] - # a = b()[i] - anno.copyanno(node.func, node, anno.Basic.QN) - return node - def visit_Name(self, node): node = self.generic_visit(node) anno.setanno(node, anno.Basic.QN, QN(node.id)) @@ -184,8 +176,9 @@ class QnResolver(gast.NodeTransformer): def visit_Attribute(self, node): node = self.generic_visit(node) - anno.setanno(node, anno.Basic.QN, - QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr)) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr)) return node def visit_Subscript(self, node): @@ -201,9 +194,10 @@ class QnResolver(gast.NodeTransformer): subscript = QN(StringLiteral(s.value.s)) else: subscript = anno.getanno(node.slice.value, anno.Basic.QN) - anno.setanno(node, anno.Basic.QN, - QN(anno.getanno(node.value, anno.Basic.QN), - subscript=subscript)) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), + subscript=subscript)) return node diff --git a/tensorflow/contrib/py2tf/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py similarity index 89% rename from tensorflow/contrib/py2tf/pyct/qual_names_test.py rename to tensorflow/contrib/autograph/pyct/qual_names_test.py index f2cd8e98f02213c9035fdb5b20e0862f0a8fd3f6..103bd25aa380e9f61ecea9c5298f34df5157d629 100644 --- a/tensorflow/contrib/py2tf/pyct/qual_names_test.py +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -20,11 +20,11 @@ from __future__ import print_function import textwrap -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.qual_names import resolve +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.qual_names import resolve from tensorflow.python.platform import test @@ -208,6 +208,24 @@ class QNResolverTest(test.TestCase): self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f') self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h') + def test_function_calls(self): + samples = """ + a.b + a.b() + a().b + z[i] + z[i]() + z()[i] + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + self.assertQNStringIs(nodes[0], 'a.b') + self.assertQNStringIs(nodes[1].func, 'a.b') + self.assertQNStringIs(nodes[2].value.func, 'a') + self.assertQNStringIs(nodes[3], 'z[i]') + self.assertQNStringIs(nodes[4].func, 'z[i]') + self.assertQNStringIs(nodes[5].value.func, 'z') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD similarity index 80% rename from tensorflow/contrib/py2tf/pyct/static_analysis/BUILD rename to tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 2799b56a0042e99b8f8b38100d07c5afaef9f424..83f3bafc4217649db6499566d548c1657428ad0b 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -25,7 +25,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", ], ) @@ -34,9 +34,10 @@ py_test( name = "activity_test", srcs = ["activity_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", "@gast_archive//:gast", ], @@ -46,9 +47,10 @@ py_test( name = "live_values_test", srcs = ["live_values_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -59,8 +61,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py rename to tensorflow/contrib/autograph/pyct/static_analysis/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 87fc8c979c4e3310fb3aa82b0f23d909b0170cda..da6a2f6f0500ebba41b85d06dcc912aae9d68f97 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -22,10 +22,10 @@ import copy import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -171,6 +171,10 @@ class ActivityAnalizer(transformer.Base): self._in_return_statement = False def _track_symbol(self, node): + # This can happen when we have an attribute (or subscript) on a function + # call. Example: a().b + if not anno.hasanno(node, anno.Basic.QN): + return qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py similarity index 95% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index b16d15b39d8eb4c444cbc50ae62baa3a8fcc7841..37c28872bb9fc4f0c6f95eec8145101b7a6c83de 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/static_analysis/annos.py rename to tensorflow/contrib/autograph/pyct/static_analysis/annos.py diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py similarity index 88% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 0388be5d252389f2f3516c8b27828905d6475589..53ae15459097baff918432a493edd7360ebf209d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -25,9 +25,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class LiveValueResolver(transformer.Base): @@ -55,11 +55,19 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() elif node.id in self.context.namespace: obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__name__,)) + if hasattr(obj, '__name__'): + anno.setanno(node, 'fqn', (obj.__name__,)) + elif hasattr(obj, '__class__'): + obj_class = obj.__class__ + anno.setanno(node, 'fqn', + (obj_class.__module__, obj_class.__name__)) + else: + # If the symbol value is for example a primitive, then it will not + # have a name. + pass else: pass # TODO(mdan): Should we raise an error here? diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py similarity index 78% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index c133a455b3dd328689102634c6076f366212ac25..69e428bde109ed43c3cdda1a94970a832dc47852 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -18,13 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -57,13 +59,30 @@ class LiveValuesResolverTest(test.TestCase): def test_literals(self): + a = None + def test_fn(): - return Foo # pylint: disable=undefined-variable + return a - node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) + node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) + def test_primitive_values(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {'a': True}) + retval_node = node.body[0].body[0].value + if six.PY2: + self.assertEqual( + anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) + else: + self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool')) + def test_namespace(self): def foo(): diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py similarity index 93% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 5556a58c025da695bcef10352c597c7c8dd612d9..203aa3c3d18ab15300bbf424adeece6e74d9c994 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -43,8 +43,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -168,6 +168,15 @@ class TypeInfoResolver(transformer.Base): anno.getanno(definition, 'element_type')) return node + def _process_tuple_assignment(self, source, t): + for i, e in enumerate(t.elts): + if isinstance(e, gast.Tuple): + self._process_tuple_assignment(source, e) + else: + self.scope.setval( + anno.getanno(e, anno.Basic.QN), + gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + def _process_variable_assignment(self, source, targets): if isinstance(source, gast.Call): func = source.func @@ -183,10 +192,9 @@ class TypeInfoResolver(transformer.Base): for t in targets: if isinstance(t, gast.Tuple): - for i, e in enumerate(t.elts): - self.scope.setval( - anno.getanno(e, anno.Basic.QN), - gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + # need to recurse on the case of assigning nested tuples, + # ex. a, (b, c) = f() + self._process_tuple_assignment(source, t) elif isinstance(t, (gast.Name, gast.Attribute)): self.scope.setval(anno.getanno(t, anno.Basic.QN), source) else: diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py similarity index 86% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 0d9d5a85f055b170ea6e493e8ac185f1298ebf3c..c0de4a604301b6e9f80ee83e4797b9ac7e558a48 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test from tensorflow.python.training import training @@ -196,6 +196,23 @@ class TypeInfoResolverTest(test.TestCase): f_ref = node.body[0].body[1].value self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_nested_assignment(self): + + def test_fn(foo): + a, (b, c) = foo + return a, b, c + + node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)}) + lhs = node.body[0].body[1].value.elts + a = lhs[0] + b = lhs[1] + c = lhs[2] + # TODO(mdan): change these once we have the live values propagating + # correctly + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py similarity index 65% rename from tensorflow/contrib/py2tf/pyct/templates.py rename to tensorflow/contrib/autograph/pyct/templates.py index cdd71dc56de33cde46d6115085350a321093d792..baf7923fff7c786c1abd05e11fa6ffdb8c8f0912 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,9 +26,9 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -44,8 +44,6 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False - # TODO(mdan): Make a more detailed pass and clean up if needed. - def visit_Expr(self, node): if (isinstance(node.value, gast.Name) and node.value.id in self.replacements): @@ -53,17 +51,66 @@ class ReplaceTransformer(gast.NodeTransformer): self.generic_visit(node) return node + def visit_keyword(self, node): + if node.arg in self.replacements: + repl = self.replacements[node.arg] + if isinstance(repl, gast.keyword): + return repl + elif (isinstance(repl, (list, tuple)) and repl and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) + return self.generic_visit(node) + def visit_FunctionDef(self, node): node = self.generic_visit(node) if node.name in self.replacements: repl = self.replacements[node.name] if not isinstance(repl, (gast.Name, ast.Name)): raise ValueError( - 'A function name can only be replaced by a Name node. Found: %s' % + 'a function name can only be replaced by a Name node. Found: %s' % repl) node.name = repl.id return node + def _check_has_context(self, node): + if not node.ctx: + raise ValueError('node %s is missing ctx value' % node) + + def _check_inner_children_have_context(self, node): + if isinstance(node, gast.Attribute): + self._check_inner_children_have_context(node.value) + self._check_has_context(node) + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._check_inner_children_have_context(e) + self._check_has_context(node) + elif isinstance(node, gast.Dict): + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._check_inner_children_have_context(node.value) + self._check_inner_children_have_context(node.slice) + elif isinstance(node, gast.Slice): + self._check_inner_children_have_context(node.lower) + if node.upper: + self._check_inner_children_have_context(node.upper) + if node.step: + self._check_inner_children_have_context(node.step) + elif isinstance(node, gast.Name): + self._check_has_context(node) + elif isinstance(node, (gast.Str, gast.Num)): + pass + else: + raise ValueError('unexpected node type "%s"' % node) + def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, ctx) @@ -74,6 +121,24 @@ class ReplaceTransformer(gast.NodeTransformer): node.ctx = ctx elif isinstance(node, gast.Name): node.ctx = ctx + elif isinstance(node, gast.Call): + self._set_inner_child_context(node.func, ctx) + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for a in node.args: + self._check_inner_children_have_context(a) + for k in node.keywords: + self._check_inner_children_have_context(k.value) + elif isinstance(node, gast.Dict): + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._set_inner_child_context(node.value, ctx) + self._check_inner_children_have_context(node.slice) elif isinstance(node, (gast.Str, gast.Num)): pass else: diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py similarity index 70% rename from tensorflow/contrib/py2tf/pyct/templates_test.py rename to tensorflow/contrib/autograph/pyct/templates_test.py index d7835b80a7f53c3ba012d01cac34b68c57bfe348..a01f8bf04c4faa6ec1779e0fb306155d99f5bd09 100644 --- a/tensorflow/contrib/py2tf/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -22,8 +22,9 @@ import imp import gast -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates from tensorflow.python.platform import test @@ -96,6 +97,50 @@ class TemplatesTest(test.TestCase): with self.assertRaises(ValueError): templates.replace(template, foo=1) + def test_replace_call_keyword(self): + template = """ + def test_fn(): + def f(a, d, f): + return a + d + f + return f(1, kws=None) + """ + + source = parser.parse_expression('f(d=3, f=5)') + node = templates.replace(template, kws=source.keywords)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(9, result.test_fn()) + + with self.assertRaises(ValueError): + templates.replace(template, kws=[]) + templates.replace(template, kws=1) + + def test_replace_name_with_call(self): + template = """ + def test_fn(): + b = 5 + def g(a): + return 3 * a + def f(): + return g + return foo + """ + + source = parser.parse_expression('f()(b)') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(15, result.test_fn()) + + def test_replace_name_with_dict(self): + template = """ + def test_fn(): + return foo['bar'] + """ + + source = parser.parse_expression('{\'bar\': 3}') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn()) + def replace_as_expression(self): template = """ foo(a) diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py similarity index 77% rename from tensorflow/contrib/py2tf/pyct/transformer.py rename to tensorflow/contrib/autograph/pyct/transformer.py index 57016bb4ce84776dfc8dfbe380322a03eb4b37b8..35f114b6e11901a854c1d631061ae42285c0e261 100644 --- a/tensorflow/contrib/py2tf/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -23,14 +23,22 @@ import sys import gast import six -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import pretty_printer -class PyFlowParseError(SyntaxError): +class AutographParseError(SyntaxError): pass +def try_ast_to_source(node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + + class Base(gast.NodeTransformer): """Base class for specialized transformers.""" @@ -62,14 +70,15 @@ class Base(gast.NodeTransformer): return super(Base, self).visit(node) except (ValueError, AttributeError, KeyError, NotImplementedError, AssertionError) as e: - msg = '%s: %s\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False)) + msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), try_ast_to_source(node), + pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] else: line = '' - six.reraise(PyFlowParseError, - PyFlowParseError( + six.reraise(AutographParseError, + AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD similarity index 95% rename from tensorflow/contrib/py2tf/utils/BUILD rename to tensorflow/contrib/autograph/utils/BUILD index 8bc338e801aa283967f4f6e6a659df9683cbc154..d3a1b9468892531cbc51bc13de66ef595f1a95f8 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -35,6 +35,7 @@ py_library( deps = [ "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", + "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], ) @@ -43,6 +44,7 @@ py_test( name = "builtins_test", srcs = ["builtins_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -83,7 +85,7 @@ py_test( name = "py_func_test", srcs = ["py_func_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22898b17e98bb004b4d2aa529b58cc99fc64dbb2 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin +from tensorflow.contrib.autograph.utils.builtins import dynamic_dataset +from tensorflow.contrib.autograph.utils.builtins import dynamic_for_cond +from tensorflow.contrib.autograph.utils.builtins import dynamic_print +from tensorflow.contrib.autograph.utils.builtins import dynamic_range +from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_while +from tensorflow.contrib.autograph.utils.py_func import wrap_py_func +from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append +from tensorflow.contrib.autograph.utils.testing import fake_tf +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/py2tf/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/builtins.py rename to tensorflow/contrib/autograph/utils/builtins.py index 3cb62b55d4d23545af4d641ecab1663ee7f7b876..c6af0e4d13b8d15bebf857ff7e1129149490ee7a 100644 --- a/tensorflow/contrib/py2tf/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -18,12 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + import six -from tensorflow.contrib.py2tf.utils import py_func -from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_inspect @@ -54,7 +58,6 @@ def dynamic_len(list_or_tensor): raise ValueError( 'len requires non-zero rank for tensor "%s"' % list_or_tensor) return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) @@ -96,4 +99,76 @@ def dynamic_print(*values): if all(map(is_tf_print_compatible, values)): return logging_ops.Print(1, values) - return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + + def flushed_print(*vals): + print(*vals) + sys.stdout.flush() + + return py_func.wrap_py_func( + flushed_print, None, values, use_dummy_return=True) + + +def dynamic_dataset(iterated): + """Implementartion of smart tf.data.Dataset epoch wrapping. + + The function checks if the input is a tf.data.Dataset and if so then wraps it + so that for each element it returns it also returns the current epoch the + dataset iteration is in, for two epochs. If the input is not a + tf.data.Dataset then it just returns the input. + + Args: + iterated: The iterable or tf.data.Dataset that is being iterated over. + Returns: + Either just the untouched input, or in the case of input being a + tf.data.Dataset then it returns a wrapped tf.data.Dataset where for each + element it returns it also returns the current epoch the dataset iteration + is in. + """ + if not isinstance(iterated, dataset_ops.Dataset): + return iterated + + def epoch_dataset_number_helper(i): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(i).repeat(), iterated)) + + epoch_numbers = dataset_ops.Dataset.range(2) + return epoch_numbers.flat_map(epoch_dataset_number_helper) + + +def dynamic_for_cond(iteration, iterated): + """Implementartion of smart while-loop condition using dynamic dispatch. + + The function checks if it is iterating over a tf.data.Dataset or not, and in + the case it is not then it simply returns if we are still in range of the + iterated and the next element. If it is iterating over a dataset then it only + iterates for a single epoch. + + Args: + iteration: The current iteration of the loop. + iterated: The iterable or tf.data.Dataset that is being iterated over. + Returns: + A tuple of a bool that indicates whether the loop should continue, and the + next element in iterated. + """ + # TODO(znado): Clean up. + # TODO(znado): This won't work for unpacked iterates. Fix. + if isinstance(iterated, dataset_ops.Dataset): + curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next() + return math_ops.less(curr_epoch, 1), next_elem + elif tensor_util.is_tensor(iterated): + if iterated.shape.ndims > 1: + elem_shape = array_ops.shape(iterated)[1:] + else: + elem_shape = () + if iterated.shape.ndims == 0 or iterated.shape[0] == 0: + return False, array_ops.zeros(elem_shape, iterated.dtype) + return control_flow_ops.cond( + math_ops.less(iteration, dynamic_len(iterated)), + lambda: (True, iterated[iteration]), + lambda: (False, array_ops.zeros(elem_shape, iterated.dtype))) + elif hasattr(iterated, '__len__'): + if iteration < len(iterated): + return True, iterated[iteration] + return False, None + else: + raise NotImplementedError('Python iterators not yet supported.') diff --git a/tensorflow/contrib/py2tf/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py similarity index 98% rename from tensorflow/contrib/py2tf/utils/builtins_test.py rename to tensorflow/contrib/autograph/utils/builtins_test.py index 59b3573d38c5bd98f416c7b77d1bc772cb8069dd..d9f7913d89a5471c76eb7ae484674bd7a1853ac9 100644 --- a/tensorflow/contrib/py2tf/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -22,7 +22,7 @@ import sys import six -from tensorflow.contrib.py2tf.utils import builtins +from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/autograph/utils/context_managers.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/context_managers.py rename to tensorflow/contrib/autograph/utils/context_managers.py diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/autograph/utils/context_managers_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/context_managers_test.py rename to tensorflow/contrib/autograph/utils/context_managers_test.py index 404f6e44e59d8bd6131367e3234843f03b351910..42e27724b9856f715b524cdd7539897851715638 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers_test.py +++ b/tensorflow/contrib/autograph/utils/context_managers_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.contrib.autograph.utils import context_managers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import tensor_array_ops diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/autograph/utils/misc.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/misc.py rename to tensorflow/contrib/autograph/utils/misc.py diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/misc_test.py rename to tensorflow/contrib/autograph/utils/misc_test.py index 8aedd4cd64798660cc07364c45487399986c9be6..71e358c33e1ea9887d267c67bc80362bac26c3a6 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/autograph/utils/misc_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils.misc import alias_tensors +from tensorflow.contrib.autograph.utils.misc import alias_tensors from tensorflow.python.framework.constant_op import constant from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py similarity index 82% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch.py index da7a942703d83b55edbd1607cb49ad4137daeb04..47049255f31113a0c7b2f5a1269593afdbbc9b19 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for type-dependent behavior used in py2tf-generated code.""" +"""Utilities for type-dependent behavior used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division @@ -20,23 +20,18 @@ from __future__ import print_function import six -from tensorflow.contrib.py2tf.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_check import is_tensor from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops def dynamic_is(left, right): - if is_tensor(left, right): - return math_ops.equal(left.name, right.name) - else: - return left is right + # TODO(alexbw) if we're sure we should leave 'is' in place, + # then change the semantics in converters/logical_expressions.py + return left is right def dynamic_is_not(left, right): - if is_tensor(left, right): - return math_ops.not_equal(left.name, right.name) - else: - return left is not right + return left is not right def run_cond(condition, true_fn, false_fn): @@ -60,10 +55,17 @@ def run_cond(condition, true_fn, false_fn): def py_cond(condition, true_fn, false_fn): + """Functional version of Python's conditional.""" if condition: - return true_fn() + results = true_fn() else: - return false_fn() + results = false_fn() + + # The contract for the branch functions is to return tuples, but they should + # be collapsed to a single element when there is only one output. + if len(results) == 1: + return results[0] + return results def run_while(cond_fn, body_fn, init_args): diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py similarity index 86% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch_test.py index 8d89b6898a366fe90ee1d43a55d0a7f10690224b..e6a41bb4166e8cfc8c703685f56eb90a1b5f63b4 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.py2tf.utils import multiple_dispatch +from tensorflow.contrib.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant from tensorflow.python.platform import test @@ -50,26 +50,25 @@ class MultipleDispatchTest(test.TestCase): should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) - self.assertTrue(should_be_true1.eval()) - self.assertTrue(should_be_true2.eval()) - self.assertFalse(should_be_false1.eval()) - self.assertFalse(should_be_false2.eval()) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) def test_run_cond_python(self): - true_fn = lambda: 2.0 - false_fn = lambda: 3.0 - self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) - self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + true_fn = lambda: (2,) + false_fn = lambda: (3,) + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3) def test_run_cond_tf(self): - - true_fn = lambda: constant([2.0]) - false_fn = lambda: constant([3.0]) + true_fn = lambda: (constant(2),) + false_fn = lambda: (constant(3),) with Session() as sess: out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) - self.assertEqual(sess.run(out), 2.0) + self.assertEqual(sess.run(out), 2) out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) - self.assertEqual(sess.run(out), 3.0) + self.assertEqual(sess.run(out), 3) def test_run_while_python(self): cond_fn = lambda x, t, s: x > t diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/contrib/autograph/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..11ebfb2e49f0e762b56ae2cde2b76d2e24032d72 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +class MatchDType(namedtuple('MatchDType', ('arg_number',))): + """Allows matching the dtype of an argument. + + Used in conjunction with function calls. For example, MatchDType(0) will + match the DType of the first argument. + """ + + pass + + +def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: None, individual of tuple/list of DType or MatchDType, the + data type for each of f's return value(s). Set to None if f has no + return values or use_dummy_return is True. Use MatchDType to define a + dtype identical to that of `i`th argument (argument 0 is the first); + an argument must of Tensor type if it is to be used with MatchDType. + args: Positional arguments for f, as list or tuple. + kwargs: Keyword arguments for f, as dict with string keys. May be None. + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if any of the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + tensor_args = [] + tensor_args_idx = {} + + # Of the positional arguments, only grab the tensor ones to be passed through + # the py_func. + n_args = len(args) + arg_is_tensor = tuple(map(tensor_util.is_tensor, args)) + for i in range(n_args): + if arg_is_tensor[i]: + tensor_args_idx[i] = len(tensor_args) + tensor_args.append(args[i]) + + # We essentially take the tensor kwargs, if any, and add them to the list of + # positional arguments. The kwargs are then reconstructed inside the py_func. + # + # For example, if + # + # args = [Tensor(1), 'foo'] + # kwargs = {'a': Tensor(2), 'b': 'bar'} + # + # Then + # + # tensor_args = (Tensor(1), Tensor(2)) + # kwarg_keys = ('a', 'b') + if kwargs: + kwarg_keys = tuple(kwargs.keys()) + kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys} + for k in kwarg_keys: + if kwarg_is_tensor[k]: + tensor_args_idx[k] = len(tensor_args) + tensor_args.append(kwargs[k]) + else: + kwarg_keys = () + + # Set up return dtypes. + def match_arg_dtype(arg_number): + arg = args[arg_number] + if not arg_is_tensor[arg_number]: + raise ValueError( + 'argument %d was used with MatchDType and must be a tf.Tensor, but ' + 'was %s instead' % (arg_number, type(arg))) + return arg.dtype + + if return_dtypes: + if isinstance(return_dtypes, MatchDType): + return_dtypes = match_arg_dtype(return_dtypes.arg_number) + elif isinstance(return_dtypes, (list, tuple)): + return_dtypes = tuple( + match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a + for a in return_dtypes) + else: + assert isinstance(return_dtypes, dtypes.DType) + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a + for i, a in enumerate(args)) + f_kwargs = { + k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k] + for i, k in enumerate(kwarg_keys) + } + retval = f(*f_args, **f_kwargs) + return 1 if use_dummy_return else retval + + return script_ops.py_func(f_wrapper, tensor_args, dtypes.int64 + if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2468263142f14332e86db99d198ba0f5c633dc69 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (1, constant_op.constant(1), 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func( + test_fn, dtypes.int64, + (constant_op.constant(1), 1, constant_op.constant(1))) + self.assertEqual(3, sess.run(result)) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass())) + self.assertEqual(35, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())) + self.assertEqual(35, sess.run(result)) + + def test_wrap_py_func_kwargs(self): + + class TestClass(object): + + def __init__(self, foo): + self.foo = foo + + def test_fn(a, b, c, d): + return a * b.foo + c * d.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), { + 'c': 11, + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass(5)), { + 'c': constant_op.constant(11), + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([1], side_counter) + result = py_func.wrap_py_func( + test_fn, None, (constant_op.constant(5),), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/autograph/utils/tensor_list.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/tensor_list.py rename to tensorflow/contrib/autograph/utils/tensor_list.py diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py similarity index 97% rename from tensorflow/contrib/py2tf/utils/tensor_list_test.py rename to tensorflow/contrib/autograph/utils/tensor_list_test.py index 110e4d105e934d9d752afc2ccc0c53c99b70d41d..d58489eb68b6b949a4276520605c62b7c2825558 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list_test.py +++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for PyFlow list.""" +"""Tests for Autograph lists.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.contrib.autograph.utils import tensor_list as tl from tensorflow.python.client.session import Session from tensorflow.python.eager import context from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/py2tf/utils/testing.py b/tensorflow/contrib/autograph/utils/testing.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/testing.py rename to tensorflow/contrib/autograph/utils/testing.py diff --git a/tensorflow/contrib/py2tf/utils/type_check.py b/tensorflow/contrib/autograph/utils/type_check.py similarity index 95% rename from tensorflow/contrib/py2tf/utils/type_check.py rename to tensorflow/contrib/autograph/utils/type_check.py index b9b2b451a4e22684a19f0d10fbf5e4fae5d6152b..8748abc47bcfb55b4d0b11178a46816249732da9 100644 --- a/tensorflow/contrib/py2tf/utils/type_check.py +++ b/tensorflow/contrib/autograph/utils/type_check.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used in py2tf-generated code.""" +"""Utilities used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/py2tf/utils/type_check_test.py b/tensorflow/contrib/autograph/utils/type_check_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/type_check_test.py rename to tensorflow/contrib/autograph/utils/type_check_test.py index 7d0428e9cccecdc67511e236bc00655a055aea29..3b67b7194c5656b193d47860f93986a985cb1aef 100644 --- a/tensorflow/contrib/py2tf/utils/type_check_test.py +++ b/tensorflow/contrib/autograph/utils/type_check_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy -from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/utils/type_hints.py b/tensorflow/contrib/autograph/utils/type_hints.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/type_hints.py rename to tensorflow/contrib/autograph/utils/type_hints.py diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index ee67909133fc26ba98355db05a4b90d3dfa6b97b..d65c990c87cbc316472237d183c03765416501e7 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -112,14 +112,3 @@ py_test( "//tensorflow/python:script_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD index 6db627faad1df4a4b73082e74e7754829ff2b514..7cb2d8079bd18660f72eab92654629434ce4d6a5 100644 --- a/tensorflow/contrib/batching/test_util/BUILD +++ b/tensorflow/contrib/batching/test_util/BUILD @@ -8,17 +8,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - cc_library( name = "fake_clock_env", testonly = 1, diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index 2a84a7712a8fa66e89db41ff4e7ebe4f620029ca..8f81b6702f2807d7da7e72190ce2d86b28e52113 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "**/google_*", - ], - ), -) - cc_library( name = "periodic_function_dynamic", hdrs = ["periodic_function.h"], diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index c6feec68e0104ff33451bbb6fa7de51d13e0a43c..5a2d7f6a3c0ba233299a5790fa80488786712f3c 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -37,25 +37,6 @@ py_library( ], ) -cuda_py_test( - name = "metropolis_hastings_test", - size = "large", - srcs = ["python/kernel_tests/metropolis_hastings_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - cuda_py_test( name = "monte_carlo_test", size = "small", @@ -76,37 +57,3 @@ cuda_py_test( "//tensorflow/python:random_seed", ], ) - -cuda_py_test( - name = "hmc_test", - size = "large", - srcs = ["python/kernel_tests/hmc_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["nomsan"], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/bayesflow/README.md b/tensorflow/contrib/bayesflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10323dc6d59918a9f8cf1840d06dcd219dfe3568 --- /dev/null +++ b/tensorflow/contrib/bayesflow/README.md @@ -0,0 +1,17 @@ +# Notice + +`tf.contrib.bayesflow` has moved! + +See new code at [github.com/tensorflow/probability]( +https://github.com/tensorflow/probability). + +Switch imports with: + +```python +# old +import tensorflow as tf +tfp = tf.contrib.bayesflow + +# new +import tensorflow_probability as tfp +``` diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index f86820382682f79e85e6a92c7f63fa15bb8be1a3..41a8c920fc4e81af90f4c94a149d8c404c58b747 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,8 +21,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo # pylint: enable=unused-import,line-too-long @@ -30,13 +28,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'entropy', - 'hmc', - 'metropolis_hastings', 'monte_carlo', - 'special_math', - 'stochastic_variables', - 'variational_inference', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py deleted file mode 100644 index dabadfc7b6a3da8786e88d559fe2d05b44599ca0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ /dev/null @@ -1,737 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Hamiltonian Monte Carlo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -import numpy as np -from scipy import stats - -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator - -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_linalg_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging_ops - - -def _reduce_variance(x, axis=None, keepdims=False): - sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) - return math_ops.reduce_mean( - math_ops.squared_difference(x, sample_mean), axis, keepdims) - - -class HMCTest(test.TestCase): - - def setUp(self): - self._shape_param = 5. - self._rate_param = 10. - - random_seed.set_random_seed(10003) - np.random.seed(10003) - - def assertAllFinite(self, x): - self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) - - def _log_gamma_log_prob(self, x, event_dims=()): - """Computes log-pdf of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - """ - return math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims) - - def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, - feed_dict=None): - step_size = array_ops.placeholder(np.float32, [], name="step_size") - hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict[hmc_lf_steps] = 1000 - - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - - m = random_ops.random_normal(array_ops.shape(x)) - log_prob_0 = self._log_gamma_log_prob(x, event_dims) - grad_0 = gradients_ops.gradients(log_prob_0, x) - old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) - - new_m, _, log_prob_1, _ = _leapfrog_integrator( - current_momentums=[m], - target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), - current_state_parts=[x], - step_sizes=[step_size], - num_leapfrog_steps=hmc_lf_steps, - current_target_log_prob=log_prob_0, - current_grads_target_log_prob=grad_0) - new_m = new_m[0] - - new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, - event_dims) - - x_shape = sess.run(x, feed_dict).shape - event_size = np.prod(x_shape[independent_chain_ndims:]) - feed_dict[step_size] = 0.1 / event_size - old_energy_, new_energy_ = sess.run([old_energy, new_energy], - feed_dict) - logging_ops.vlog(1, "average energy relative change: {}".format( - (1. - new_energy_ / old_energy_).mean())) - self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) - - def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): - """Tests the long-term energy conservation of the leapfrog integrator. - - The leapfrog integrator is symplectic, so for sufficiently small step - sizes it should be possible to run it more or less indefinitely without - the energy of the system blowing up or collapsing. - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._integrator_conserves_energy(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper(0) - - def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper(1) - - def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper(2) - - def testIntegratorEnergyConservation3(self): - self._integrator_conserves_energy_wrapper(3) - - def testSampleChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - num_results = 10 - independent_chain_ndims = 1 - - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - kwargs = dict( - target_log_prob_fn=log_gamma_log_prob, - current_state=np.random.rand(4, 3, 2), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=150, - seed=52, - ) - - samples0, kernel_results0 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=2 * num_results, - num_steps_between_results=0).items()))) - - samples1, kernel_results1 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=num_results, - num_steps_between_results=1).items()))) - - [ - samples0_, - samples1_, - target_log_prob0_, - target_log_prob1_, - ] = sess.run([ - samples0, - samples1, - kernel_results0.current_target_log_prob, - kernel_results1.current_target_log_prob, - ]) - self.assertAllClose(samples0_[::2], samples1_, - atol=1e-5, rtol=1e-5) - self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, - atol=1e-5, rtol=1e-5) - - def _chain_gets_correct_expectations(self, x, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - def log_gamma_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - num_results = array_ops.placeholder( - np.int32, [], name="num_results") - step_size = array_ops.placeholder( - np.float32, [], name="step_size") - num_leapfrog_steps = array_ops.placeholder( - np.int32, [], name="num_leapfrog_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict.update({num_results: 150, - step_size: 0.05, - num_leapfrog_steps: 2}) - - samples, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_gamma_log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - num_burnin_steps=150, - seed=42) - - self.assertAllEqual(dict(target_calls=2), counter) - - expected_x = (math_ops.digamma(self._shape_param) - - np.log(self._rate_param)) - - expected_exp_x = self._shape_param / self._rate_param - - log_accept_ratio_, samples_, expected_x_ = sess.run( - [kernel_results.log_accept_ratio, samples, expected_x], - feed_dict) - - actual_x = samples_.mean() - actual_exp_x = np.exp(samples_).mean() - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( - expected_x_, expected_exp_x)) - logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( - actual_x, actual_exp_x)) - self.assertNear(actual_x, expected_x_, 2e-2) - self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), - acceptance_probs > 0.5) - self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), - acceptance_probs <= 1.) - - def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper(0) - - def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper(1) - - def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper(2) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, # Non-constant gradient. - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # This log_prob has the property that it is likely to attract - # the flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - # Why use a non-constant gradient? This ensures the leapfrog integrator - # will not be exact. - - num_results = 1000 - # Large step size, will give rejections due to integration error in addition - # to rejection due to going into a region of log_prob = -inf. - step_size = 0.1 - num_leapfrog_steps = 5 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_prob, - current_state=initial_state, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - seed=42) - - states_, kernel_results_ = sess.run([states, kernel_results]) - pstates_ = kernel_results_.proposed_state - - neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) - - # First: Test that the mathematical properties of the above log prob - # function in conjunction with HMC show up as expected in kernel_results_. - - # We better have log_prob = -inf some of the time. - self.assertLess(0, neg_inf_mask.sum()) - # We better have some rejections due to something other than -inf. - self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) - # We better have accepted a decent amount, even near end of the chain. - self.assertLess( - 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) - # We better not have any NaNs in states or log_prob. - # We may have some NaN in grads, which involve multiplication/addition due - # to gradient rules. This is the known "NaN grad issue with tf.where." - self.assertAllEqual(np.zeros_like(states_), - np.isnan(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states, grads, or log_prob. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual( - np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # Second: Test that kernel_results is congruent with itself and - # acceptance/rejection of states. - - # Proposed state is negative iff proposed target log prob is -inf. - np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) - np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) - - # Acceptance probs are zero whenever proposed state is negative. - acceptance_probs = np.exp(np.minimum( - kernel_results_.log_accept_ratio, 0.)) - self.assertAllEqual( - np.zeros_like(pstates_[neg_inf_mask]), - acceptance_probs[neg_inf_mask]) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[kernel_results_.is_accepted], - pstates_[kernel_results_.is_accepted], - ) - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if kernel_results_.is_accepted[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def _kernel_leaves_target_invariant(self, initial_draws, - independent_chain_ndims, - sess, feed_dict=None): - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - def fake_log_prob(x): - """Cooled version of the target distribution.""" - return 1.1 * log_gamma_log_prob(x) - - step_size = array_ops.placeholder(np.float32, [], name="step_size") - - if feed_dict is None: - feed_dict = {} - - feed_dict[step_size] = 0.4 - - sample, kernel_results = hmc.kernel( - target_log_prob_fn=log_gamma_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=43) - - bad_sample, bad_kernel_results = hmc.kernel( - target_log_prob_fn=fake_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=44) - - [ - log_accept_ratio_, - bad_log_accept_ratio_, - initial_draws_, - updated_draws_, - fake_draws_, - ] = sess.run([ - kernel_results.log_accept_ratio, - bad_kernel_results.log_accept_ratio, - initial_draws, - sample, - bad_sample, - ], feed_dict) - - # Confirm step size is small enough that we usually accept. - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - bad_acceptance_probs = np.exp(np.minimum(bad_log_accept_ratio_, 0.)) - self.assertGreater(acceptance_probs.mean(), 0.5) - self.assertGreater(bad_acceptance_probs.mean(), 0.5) - - # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs.mean(), 0.99) - self.assertLess(bad_acceptance_probs.mean(), 0.99) - - _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), - updated_draws_.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), - fake_draws_.flatten()) - - logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs.mean())) - logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs.mean())) - logging_ops.vlog(1, "K-S p-value for true target: {}".format( - ks_p_value_true)) - logging_ops.vlog(1, "K-S p-value for fake target: {}".format( - ks_p_value_fake)) - # Make sure that the MCMC update hasn't changed the empirical CDF much. - self.assertGreater(ks_p_value_true, 1e-3) - # Confirm that targeting the wrong distribution does - # significantly change the empirical CDF. - self.assertLess(ks_p_value_fake, 1e-6) - - def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): - """Tests that the kernel leaves the target distribution invariant. - - Draws some independent samples from the target distribution, - applies an iteration of the MCMC kernel, then runs a - Kolmogorov-Smirnov test to determine if the distribution of the - MCMC-updated samples has changed. - - We also confirm that running the kernel with a different log-pdf - does change the target distribution. (And that we can detect that.) - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - initial_draws = np.log(np.random.gamma(self._shape_param, - size=[50000, 2, 2])) - initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name="x_ph") - - feed_dict = {x_ph: initial_draws} - - self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper(1) - - def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper(2) - - def testKernelLeavesTargetInvariant3(self): - self._kernel_leaves_target_invariant_wrapper(3) - - def testNanRejection(self): - """Tests that an update that yields NaN potentials gets rejected. - - We run HMC with a target distribution that returns NaN - log-likelihoods if any element of x < 0, and unit-scale - exponential log-likelihoods otherwise. The exponential potential - pushes x towards 0, ensuring that any reasonably large update will - push us over the edge into NaN territory. - """ - def _unbounded_exponential_log_prob(x): - """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where( - x < 0., - array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), - -x) - return math_ops.reduce_sum(per_element_potentials) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_unbounded_exponential_log_prob, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=46) - initial_x_, updated_x_, log_accept_ratio_ = sess.run( - [initial_x, updated_x, kernel_results.log_accept_ratio]) - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs, 0.) - - def testNanFromGradsDontPropagate(self): - """Test that update with NaN gradients does not cause NaN in results.""" - def _nan_log_prob_with_nan_gradient(x): - return np.nan * math_ops.reduce_sum(x) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_nan_log_prob_with_nan_gradient, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=47) - initial_x_, updated_x_, log_accept_ratio_ = sess.run( - [initial_x, updated_x, kernel_results.log_accept_ratio]) - acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs, 0.) - - self.assertAllFinite( - gradients_ops.gradients(updated_x, initial_x)[0].eval()) - self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, initial_x)]) - self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, - kernel_results.proposed_state)]) - - # Gradients of the acceptance probs and new log prob are not finite. - # self.assertAllFinite( - # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) - # self.assertAllFinite( - # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) - - def _testChainWorksDtype(self, dtype): - with self.test_session(graph=ops.Graph()) as sess: - states, kernel_results = hmc.sample_chain( - num_results=10, - target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), - current_state=np.zeros(5).astype(dtype), - step_size=0.01, - num_leapfrog_steps=10, - seed=48) - states_, log_accept_ratio_ = sess.run( - [states, kernel_results.log_accept_ratio]) - self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, log_accept_ratio_.dtype) - - def testChainWorksIn64Bit(self): - self._testChainWorksDtype(np.float64) - - def testChainWorksIn16Bit(self): - self._testChainWorksDtype(np.float16) - - def testChainWorksCorrelatedMultivariate(self): - dtype = np.float32 - true_mean = dtype([0, 0]) - true_cov = dtype([[1, 0.5], - [0.5, 1]]) - num_results = 2000 - counter = collections.Counter() - with self.test_session(graph=ops.Graph()) as sess: - def target_log_prob(x, y): - counter["target_calls"] += 1 - # Corresponds to unnormalized MVN. - # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) - z = array_ops.stack([x, y], axis=-1) - true_mean - z = array_ops.squeeze( - gen_linalg_ops.matrix_triangular_solve( - np.linalg.cholesky(true_cov), - z[..., array_ops.newaxis]), - axis=-1) - return -0.5 * math_ops.reduce_sum(z**2., axis=-1) - states, _ = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=target_log_prob, - current_state=[dtype(-2), dtype(2)], - step_size=[0.5, 0.5], - num_leapfrog_steps=2, - num_burnin_steps=200, - num_steps_between_results=1, - seed=54) - self.assertAllEqual(dict(target_calls=2), counter) - states = array_ops.stack(states, axis=-1) - self.assertEqual(num_results, states.shape[0].value) - sample_mean = math_ops.reduce_mean(states, axis=0) - x = states - sample_mean - sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) - [sample_mean_, sample_cov_] = sess.run([ - sample_mean, sample_cov]) - self.assertAllClose(true_mean, sample_mean_, - atol=0.05, rtol=0.) - self.assertAllClose(true_cov, sample_cov_, - atol=0., rtol=0.1) - - -class _EnergyComputationTest(object): - - def testHandlesNanFromPotential(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - target_log_prob, proposed_target_log_prob = [ - self.dtype(x.flatten()) for x in np.meshgrid(x, x)] - num_chains = len(target_log_prob) - dummy_momentums = [-1, 1] - momentums = [self.dtype([dummy_momentums] * num_chains)] - proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - self.assertAllEqual(np.ones_like(grads_).astype(np.bool), - np.isfinite(grads_)) - - def testHandlesNanFromKinetic(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - momentums, proposed_momentums = [ - [np.reshape(self.dtype(x), [-1, 1])] - for x in np.meshgrid(x, x)] - num_chains = len(momentums[0]) - target_log_prob = np.ones(num_chains, self.dtype) - proposed_target_log_prob = np.ones(num_chains, self.dtype) - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - g = grads_[0].reshape([len(x), len(x)])[:, 0] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) - - # The remaining gradients are nan because the momentum was itself nan or - # inf. - g = grads_[0].reshape([len(x), len(x)])[:, 1:] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) - - -class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): - dtype = np.float16 - - -class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): - dtype = np.float32 - - -class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): - dtype = np.float64 - - -class _HMCHandlesLists(object): - - def testStateParts(self): - with self.test_session(graph=ops.Graph()) as sess: - dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) - dist_y = independent_lib.Independent( - gamma_lib.Gamma(concentration=self.dtype([1, 2]), - rate=self.dtype([0.5, 0.75])), - reinterpreted_batch_ndims=1) - def target_log_prob(x, y): - return dist_x.log_prob(x) + dist_y.log_prob(y) - x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] - samples, _ = hmc.sample_chain( - num_results=int(2e3), - target_log_prob_fn=target_log_prob, - current_state=x0, - step_size=0.85, - num_leapfrog_steps=3, - num_burnin_steps=int(250), - seed=49) - actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] - actual_vars = [_reduce_variance(s, axis=0) for s in samples] - expected_means = [dist_x.mean(), dist_y.mean()] - expected_vars = [dist_x.variance(), dist_y.variance()] - [ - actual_means_, - actual_vars_, - expected_means_, - expected_vars_, - ] = sess.run([ - actual_means, - actual_vars, - expected_means, - expected_vars, - ]) - self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) - self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) - - -class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): - dtype = np.float32 - - -class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py deleted file mode 100644 index f508e5b114a55fc1aeb07212595fda45fc308c7b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Metropolis-Hastings.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh -from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -class MetropolisHastingsTest(test.TestCase): - - def testKernelStateTensor(self): - """Test that transition kernel works with tensor input to `state`.""" - loc = variable_scope.get_variable("loc", initializer=0.) - - def target_log_prob_fn(loc): - return normal_lib.Normal(loc=0.0, scale=0.1).log_prob(loc) - - new_state, _ = mh.kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=mh.proposal_normal(scale=0.05), - current_state=loc, - seed=231251) - loc_update = loc.assign(new_state) - - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - loc_samples = [] - for _ in range(2500): - loc_sample = sess.run(loc_update) - loc_samples.append(loc_sample) - loc_samples = loc_samples[500:] # drop samples for burn-in - - self.assertAllClose(np.mean(loc_samples), 0.0, rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_samples), 0.1, rtol=1e-5, atol=1e-1) - - def testKernelStateList(self): - """Test that transition kernel works with list input to `state`.""" - num_chains = 2 - loc_one = variable_scope.get_variable( - "loc_one", [num_chains], - initializer=init_ops.zeros_initializer()) - loc_two = variable_scope.get_variable( - "loc_two", [num_chains], initializer=init_ops.zeros_initializer()) - - def target_log_prob_fn(loc_one, loc_two): - loc = array_ops.stack([loc_one, loc_two]) - log_prob = mvn_tril_lib.MultivariateNormalTriL( - loc=constant_op.constant([0., 0.]), - scale_tril=constant_op.constant([[0.1, 0.1], [0.0, 0.1]])).log_prob( - loc) - return math_ops.reduce_sum(log_prob, 0) - - def proposal_fn(loc_one, loc_two): - loc_one_proposal = mh.proposal_normal(scale=0.05) - loc_two_proposal = mh.proposal_normal(scale=0.05) - loc_one_sample, _ = loc_one_proposal(loc_one) - loc_two_sample, _ = loc_two_proposal(loc_two) - return [loc_one_sample, loc_two_sample], None - - new_state, _ = mh.kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=proposal_fn, - current_state=[loc_one, loc_two], - seed=12415) - loc_one_update = loc_one.assign(new_state[0]) - loc_two_update = loc_two.assign(new_state[1]) - - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - loc_one_samples = [] - loc_two_samples = [] - for _ in range(10000): - loc_one_sample, loc_two_sample = sess.run( - [loc_one_update, loc_two_update]) - loc_one_samples.append(loc_one_sample) - loc_two_samples.append(loc_two_sample) - - loc_one_samples = np.array(loc_one_samples) - loc_two_samples = np.array(loc_two_samples) - loc_one_samples = loc_one_samples[1000:] # drop samples for burn-in - loc_two_samples = loc_two_samples[1000:] # drop samples for burn-in - - self.assertAllClose(np.mean(loc_one_samples, 0), - np.array([0.] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.mean(loc_two_samples, 0), - np.array([0.] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_one_samples, 0), - np.array([0.1] * num_chains), - rtol=1e-5, atol=1e-1) - self.assertAllClose(np.std(loc_two_samples, 0), - np.array([0.1] * num_chains), - rtol=1e-5, atol=1e-1) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # The truncated distribution has the property that it is likely to attract - # the flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - - num_results = 1000 - # Large step size, will give rejections due to going into a region of - # log_prob = -inf. - step_size = 0.3 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states = [] - is_accepted = [] - proposed_states = [] - current_state = initial_state - for _ in range(num_results): - current_state, kernel_results = mh.kernel( - target_log_prob_fn=log_prob, - proposal_fn=mh.proposal_uniform(step_size=step_size), - current_state=current_state, - seed=42) - states.append(current_state) - proposed_states.append(kernel_results.proposed_state) - is_accepted.append(kernel_results.is_accepted) - - states = array_ops.stack(states) - proposed_states = array_ops.stack(proposed_states) - is_accepted = array_ops.stack(is_accepted) - states_, pstates_, is_accepted_ = sess.run( - [states, proposed_states, is_accepted]) - - # We better have accepted a decent amount, even near end of the chain. - self.assertLess( - 0.1, is_accepted_[int(0.9 * num_results):].mean()) - # We better not have any NaNs in states. - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[is_accepted_], - pstates_[is_accepted_], - ) - - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if is_accepted_[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def testDensityIncreasingStepAccepted(self): - """Tests that if a transition increases density, it is always accepted.""" - target_log_density = lambda x: - x * x - state = variable_scope.get_variable("state", initializer=10.) - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - "log_accept_ratio", initializer=0.) - - get_next_proposal = lambda x: (x - 1., None) - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, get_next_proposal, seed=1234) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - for j in range(9): - sess.run(step) - sample = sess.run(state) - sample_log_density = sess.run(state_log_density) - self.assertAlmostEqual(sample, 9 - j) - self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) - - def testSampleProperties(self): - """Tests that the samples converge to the target distribution.""" - - def target_log_density(x): - """Log-density corresponding to a normal distribution with mean = 4.""" - return - (x - 2.0) * (x - 2.0) * 0.5 - - # Use the uniform random walker to generate proposals. - proposal_fn = mh.proposal_uniform( - step_size=1.0, seed=1234) - - state = variable_scope.get_variable("state", initializer=0.0) - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - "log_accept_ratio", initializer=0.) - - # Random walk MCMC converges slowly so need to put in enough iterations. - num_iterations = 5000 - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, proposal_fn, seed=4321) - - init = variables.global_variables_initializer() - - sample_sum, sample_sq_sum = 0.0, 0.0 - with self.test_session() as sess: - sess.run(init) - for _ in np.arange(num_iterations): - # Allow for the mixing of the chain and discard these samples. - sess.run(step) - for _ in np.arange(num_iterations): - sess.run(step) - sample = sess.run(state) - sample_sum += sample - sample_sq_sum += sample * sample - - sample_mean = sample_sum / num_iterations - sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean - # The samples have large autocorrelation which reduces the effective sample - # size. - self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) - self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) - - def testProposalNormal(self): - """Tests that the normal proposals are correctly distributed.""" - - initial_points = array_ops.ones([10000], dtype=dtypes.float32) - proposal_fn = mh.proposal_normal( - scale=2.0, seed=1234) - proposal_points, _ = proposal_fn(initial_points) - - with self.test_session() as sess: - sample = sess.run(proposal_points) - - # It is expected that the elements in proposal_points have the same mean as - # initial_points and have the standard deviation that was supplied to the - # proposal scheme. - self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) - self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) - - def testDocstringExample(self): - """Tests the simplified docstring example with multiple chains.""" - - n = 2 # dimension of the problem - - # Generate 300 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = variable_scope.get_variable( - "state", initializer=random_ops.random_normal( - [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = variable_scope.get_variable( - "state_log_density", - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = variable_scope.get_variable( - "log_acceptance_ratio", - initializer=array_ops.zeros([300], dtype=dtypes.float32)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + random_ops.random_uniform( - array_ops.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps. - for _ in range(10): - sess.run(stepper) - samples = sess.run(state) - covariance = np.eye(n) - # Verify that the estimated mean and covariance are close to the true - # values. - self.assertAlmostEqual( - np.max(np.abs(np.mean(samples, 0) - - np.zeros(n))), 0, - delta=0.1) - self.assertAlmostEqual( - np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2]) - - np.reshape(covariance, [n**2]))), 0, - delta=0.2) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py deleted file mode 100644 index 66afcc749746ab5c04114e585c5f93a3f3354d86..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ /dev/null @@ -1,961 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. - -@@sample_chain -@@kernel -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import util as distributions_util - -__all__ = [ - "sample_chain", - "kernel", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "log_accept_ratio", - "current_grads_target_log_prob", # "Current result" means "accepted". - "current_target_log_prob", # "Current result" means "accepted". - "is_accepted", - "proposed_grads_target_log_prob", - "proposed_state", - "proposed_target_log_prob", - ]) - - -def _make_dummy_kernel_results( - dummy_state, - dummy_target_log_prob, - dummy_grads_target_log_prob): - return KernelResults( - log_accept_ratio=dummy_target_log_prob, - current_grads_target_log_prob=dummy_grads_target_log_prob, - current_target_log_prob=dummy_target_log_prob, - is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), - proposed_grads_target_log_prob=dummy_grads_target_log_prob, - proposed_state=dummy_state, - proposed_target_log_prob=dummy_target_log_prob, - ) - - -def sample_chain( - num_results, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - num_burnin_steps=0, - num_steps_between_results=0, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm - that takes a series of gradient-informed steps to produce a Metropolis - proposal. This function samples from an HMC Markov chain at `current_state` - and whose stationary distribution has log-unnormalized-density - `target_log_prob_fn()`. - - This function samples from multiple chains in parallel. It assumes that the - the leftmost dimensions of (each) `current_state` (part) index an independent - chain. The function `target_log_prob_fn()` sums log-probabilities across - event dimensions (i.e., current state (part) rightmost dimensions). Each - element of the output of `target_log_prob_fn()` represents the (possibly - unnormalized) log-probability of the joint distribution over (all) the current - state (parts). - - The `current_state` can be represented as a single `Tensor` or a `list` of - `Tensors` which collectively represent the current state. When specifying a - `list`, one must also specify a list of `step_size`s. - - Note: `target_log_prob_fn` is called exactly twice. - - Since HMC states are correlated, it is sometimes desirable to produce - additional intermediate states, and then discard them, ending up with a set of - states with decreased autocorrelation. See [1]. Such "thinning" is made - possible by setting `num_steps_between_results > 0`. The chain then takes - `num_steps_between_results` extra steps between the steps that make it into - the results. The extra steps are never materialized (in calls to `sess.run`), - and thus do not increase memory requirements. - - [1]: "Statistically efficient thinning of a Markov chain sampler." - Art B. Owen. April 2017. - http://statweb.stanford.edu/~owen/reports/bestthinning.pdf - - #### Examples: - - ##### Sample from a diagonal-variance Gaussian. - - ```python - tfd = tf.contrib.distributions - - def make_likelihood(true_variances): - return tfd.MultivariateNormalDiag( - scale_diag=tf.sqrt(true_variances)) - - dims = 10 - dtype = np.float32 - true_variances = tf.linspace(dtype(1), dtype(3), dims) - likelihood = make_likelihood(true_variances) - - states, kernel_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=likelihood.log_prob, - current_state=tf.zeros(dims), - step_size=0.5, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(states, axis=0) - sample_var = tf.reduce_mean( - tf.squared_difference(states, sample_mean), - axis=0) - ``` - - ##### Sampling from factor-analysis posteriors with known factors. - - I.e., - - ```none - for i=1..n: - w[i] ~ Normal(0, eye(d)) # prior - x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood - ``` - - where `F` denotes factors. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, factors): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) - - # Setup data. - num_weights = 10 - num_factors = 4 - num_chains = 100 - dtype = np.float32 - - prior = make_prior(num_weights, dtype) - weights = prior.sample(num_chains) - factors = np.random.randn(num_factors, num_weights).astype(dtype) - x = make_likelihood(weights, factors).sample(num_chains) - - def target_log_prob(w): - # Target joint is: `f(w) = p(w, x | factors)`. - return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) - - # Get `num_results` samples from `num_chains` independent chains. - chains_states, kernels_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target_log_prob, - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) - sample_var = tf.reduce_mean( - tf.squared_difference(chains_states, sample_mean), - axis=[0, 1]) - ``` - - Args: - num_results: Integer number of Markov chain draws. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - num_burnin_steps: Integer number of chain steps to take before starting to - collect results. - Default value: 0 (i.e., no burn-in). - num_steps_between_results: Integer number of chain steps between collecting - a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. The number of returned chain states is - still equal to `num_results`. Default value: 0 (i.e., no thinning). - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob` at the `current_state` and wrt - the `current_state`. Must have same shape as `current_state`. The only - reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_chain"). - - Returns: - next_states: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state` but with a prepended `num_results`-size dimension. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - with ops.name_scope( - name, "hmc_sample_chain", - [num_results, current_state, step_size, num_leapfrog_steps, - num_burnin_steps, num_steps_between_results, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob, - ] = _prepare_args( - target_log_prob_fn, - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob) - num_results = ops.convert_to_tensor( - num_results, - dtype=dtypes.int32, - name="num_results") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - num_burnin_steps = ops.convert_to_tensor( - num_burnin_steps, - dtype=dtypes.int32, - name="num_burnin_steps") - num_steps_between_results = ops.convert_to_tensor( - num_steps_between_results, - dtype=dtypes.int32, - name="num_steps_between_results") - - def _run_chain(num_steps, current_state, kernel_results): - """Runs the chain(s) for `num_steps`.""" - def _loop_body(iter_, current_state, kernel_results): - return [iter_ + 1] + list(kernel( - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), - current_state, - kernel_results, - ], - ) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - return control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - def _scan_body(args_list, iter_): - """Closure which implements `tf.scan` body.""" - current_state, kernel_results = args_list - return _run_chain( - 1 + array_ops.where(math_ops.equal(iter_, 0), - num_burnin_steps, - num_steps_between_results), - current_state, - kernel_results) - - scan_kwargs = dict( - fn=_scan_body, - elems=math_ops.range(num_results), # iter_: used to choose burnin. - initializer=[ - current_state, - _make_dummy_kernel_results( - current_state, - current_target_log_prob, - current_grads_target_log_prob), - ]) - if seed is not None: - scan_kwargs["parallel_iterations"] = 1 - return functional_ops.scan(**scan_kwargs) - - -def kernel(target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs one iteration of Hamiltonian Monte Carlo. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function applies one step of HMC to - randomly update the variable `x`. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - #### Examples: - - ##### Simple chain with warm-up. - - ```python - tfd = tf.contrib.distributions - - # Tuning acceptance rates: - dtype = np.float32 - target_accept_rate = 0.631 - num_warmup_iter = 500 - num_chain_iter = 500 - - x = tf.get_variable(name="x", initializer=dtype(1)) - step_size = tf.get_variable(name="step_size", initializer=dtype(1)) - - target = tfd.Normal(loc=dtype(0), scale=dtype(1)) - - next_x, other_results = hmc.kernel( - target_log_prob_fn=target.log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=3)[:4] - - x_update = x.assign(next_x) - - step_size_update = step_size.assign_add( - step_size * tf.where( - tf.exp(tf.minimum(other_results.log_accept_ratio), 0.) > - target_accept_rate, - 0.01, -0.01)) - - warmup = tf.group([x_update, step_size_update]) - - tf.global_variables_initializer().run() - - sess.graph.finalize() # No more graph building. - - # Warm up the sampler and adapt the step size - for _ in xrange(num_warmup_iter): - sess.run(warmup) - - # Collect samples without adapting step size - samples = np.zeros([num_chain_iter]) - for i in xrange(num_chain_iter): - _, x_, target_log_prob_, grad_ = sess.run([ - x_update, - x, - other_results.target_log_prob, - other_results.grads_target_log_prob]) - samples[i] = x_ - - print(samples.mean(), samples.std()) - ``` - - ##### Sample from more complicated posterior. - - I.e., - - ```none - W ~ MVN(loc=0, scale=sigma * eye(dims)) - for i=1...num_samples: - X[i] ~ MVN(loc=0, scale=eye(dims)) - eps[i] ~ Normal(loc=0, scale=1) - Y[i] = X[i].T * W + eps[i] - ``` - - ```python - tfd = tf.contrib.distributions - - def make_training_data(num_samples, dims, sigma): - dt = np.asarray(sigma).dtype - zeros = tf.zeros(dims, dtype=dt) - x = tfd.MultivariateNormalDiag( - loc=zeros).sample(num_samples, seed=1) - w = tfd.MultivariateNormalDiag( - loc=zeros, - scale_identity_multiplier=sigma).sample(seed=2) - noise = tfd.Normal( - loc=dt(0), - scale=dt(1)).sample(num_samples, seed=3) - y = tf.tensordot(x, w, axes=[[1], [0]]) + noise - return y, x, w - - def make_prior(sigma, dims): - # p(w | sigma) - return tfd.MultivariateNormalDiag( - loc=tf.zeros([dims], dtype=sigma.dtype), - scale_identity_multiplier=sigma) - - def make_likelihood(x, w): - # p(y | x, w) - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(x, w, axes=[[1], [0]])) - - # Setup assumptions. - dtype = np.float32 - num_samples = 150 - dims = 10 - num_iters = int(5e3) - - true_sigma = dtype(0.5) - y, x, true_weights = make_training_data(num_samples, dims, true_sigma) - - # Estimate of `log(true_sigma)`. - log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) - sigma = tf.exp(log_sigma) - - # State of the Markov chain. - weights = tf.get_variable( - name="weights", - initializer=np.random.randn(dims).astype(dtype)) - - prior = make_prior(sigma, dims) - - def joint_log_prob_fn(w): - # f(w) = log p(w, y | x) - return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) - - weights_update = weights.assign( - hmc.kernel(target_log_prob_fn=joint_log_prob, - current_state=weights, - step_size=0.1, - num_leapfrog_steps=5)[0]) - - with tf.control_dependencies([weights_update]): - loss = -prior.log_prob(weights) - - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) - - sess.graph.finalize() # No more graph building. - - tf.global_variables_initializer().run() - - sigma_history = np.zeros(num_iters, dtype) - weights_history = np.zeros([num_iters, dims], dtype) - - for i in xrange(num_iters): - _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights]) - weights_history[i, :] = weights_ - sigma_history[i] = sigma_ - - true_weights_ = sess.run(true_weights) - - # Should converge to something close to true_sigma. - plt.plot(sigma_history); - plt.ylabel("sigma"); - plt.xlabel("iteration"); - ``` - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `current_target_log_prob` at the `current_state` - and wrt the `current_state`. Must have same shape as `current_state`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_kernel"). - - Returns: - next_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - Raises: - ValueError: if there isn't one `step_size` or a list with same length as - `current_state`. - """ - with ops.name_scope( - name, "hmc_kernel", - [current_state, step_size, num_leapfrog_steps, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [current_state_parts, step_sizes, current_target_log_prob, - current_grads_target_log_prob] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob, - maybe_expand=True) - independent_chain_ndims = distributions_util.prefer_static_rank( - current_target_log_prob) - current_momentums = [] - for s in current_state_parts: - current_momentums.append(random_ops.random_normal( - shape=array_ops.shape(s), - dtype=s.dtype.base_dtype, - seed=seed)) - seed = distributions_util.gen_new_seed( - seed, salt="hmc_kernel_momentums") - - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] = _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob, - current_grads_target_log_prob) - - energy_change = _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims) - log_accept_ratio = -energy_change - - # u < exp(log_accept_ratio), where u~Uniform[0,1) - # ==> log(u) < log_accept_ratio - random_value = random_ops.random_uniform( - shape=array_ops.shape(energy_change), - dtype=energy_change.dtype, - seed=seed) - random_negative = math_ops.log(random_value) - is_accepted = random_negative < log_accept_ratio - - accepted_target_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - - next_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] - - accepted_grads_target_log_prob = [ - _choose(is_accepted, - proposed_grad, - grad, - independent_chain_ndims) - for proposed_grad, grad - in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] - - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(next_state_parts), - KernelResults( - log_accept_ratio=log_accept_ratio, - current_grads_target_log_prob=accepted_grads_target_log_prob, - current_target_log_prob=accepted_target_log_prob, - is_accepted=is_accepted, - proposed_grads_target_log_prob=proposed_grads_target_log_prob, - proposed_state=maybe_flatten(proposed_state_parts), - proposed_target_log_prob=proposed_target_log_prob, - ), - ] - - -def _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Applies `num_leapfrog_steps` of the leapfrog integrator. - - Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. - - #### Examples: - - ##### Simple quadratic potential. - - ```python - tfd = tf.contrib.distributions - - dims = 10 - num_iter = int(1e3) - dtype = np.float32 - - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - - [ - next_momentums, - next_positions, - ] = hmc._leapfrog_integrator( - current_momentums=[momentum], - target_log_prob_fn=tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)).log_prob, - current_state_parts=[position], - step_sizes=0.1, - num_leapfrog_steps=3)[:2] - - sess.graph.finalize() # No more graph building. - - momentum_ = np.random.randn(dims).astype(dtype) - position_ = np.random.randn(dims).astype(dtype) - - positions = np.zeros([num_iter, dims], dtype) - for i in xrange(num_iter): - position_, momentum_ = sess.run( - [next_momentums[0], next_position[0]], - feed_dict={position: position_, momentum: momentum_}) - positions[i] = position_ - - plt.plot(positions[:, 0]); # Sinusoidal. - ``` - - Args: - current_momentums: Tensor containing the value(s) of the momentum - variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like - `*current_state_parts` and returns its (possibly unnormalized) log-density - under the target distribution. - current_state_parts: Python `list` of `Tensor`s representing the current - state(s) of the Markov chain(s). The first `independent_chain_ndims` of - the `Tensor`(s) index different chains. - step_sizes: Python `list` of `Tensor`s representing the step size for the - leapfrog integrator. Must broadcast with the shape of - `current_state_parts`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. When - possible, it's often helpful to match per-variable step sizes to the - standard deviations of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn(*current_state_parts)`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob_fn(*current_state_parts`) wrt - `current_state_parts`. Must have same shape as `current_state_parts`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_leapfrog_integrator"). - - Returns: - proposed_momentums: Updated value of the momentum. - proposed_state_parts: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state_parts`. - proposed_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `next_state`. - proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt - `next_state`. - - Raises: - ValueError: if `len(momentums) != len(state_parts)`. - ValueError: if `len(state_parts) != len(step_sizes)`. - ValueError: if `len(state_parts) != len(grads_target_log_prob)`. - TypeError: if `not target_log_prob.dtype.is_floating`. - """ - def _loop_body(step, - current_momentums, - current_state_parts, - ignore_current_target_log_prob, # pylint: disable=unused-argument - current_grads_target_log_prob): - return [step + 1] + list(_leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob)) - - with ops.name_scope( - name, "hmc_leapfrog_integrator", - [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, - current_target_log_prob, current_grads_target_log_prob]): - if len(current_momentums) != len(current_state_parts): - raise ValueError("`momentums` must be in one-to-one correspondence " - "with `state_parts`") - num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, - name="num_leapfrog_steps") - current_target_log_prob, current_grads_target_log_prob = ( - _maybe_call_fn_and_grads( - target_log_prob_fn, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob)) - return control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_leapfrog_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - current_momentums, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob, - ], - back_prop=False)[1:] # Lop-off "iter_". - - -def _leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob, - name=None): - """Applies one step of the leapfrog integrator.""" - with ops.name_scope( - name, "_leapfrog_step", - [current_momentums, current_state_parts, step_sizes, - current_grads_target_log_prob]): - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(current_momentums, - step_sizes, - current_grads_target_log_prob)] - proposed_state_parts = [x + ss * m for x, ss, m - in zip(current_state_parts, - step_sizes, - proposed_momentums)] - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - if not proposed_target_log_prob.dtype.is_floating: - raise TypeError("`target_log_prob_fn` must produce a `Tensor` " - "with `float` `dtype`.") - proposed_grads_target_log_prob = gradients_ops.gradients( - proposed_target_log_prob, proposed_state_parts) - if any(g is None for g in proposed_grads_target_log_prob): - raise ValueError( - "Encountered `None` gradient. Does your target `target_log_prob_fn` " - "access all `tf.Variable`s via `tf.get_variable`?\n" - " current_state_parts: {}\n" - " proposed_state_parts: {}\n" - " proposed_grads_target_log_prob: {}".format( - current_state_parts, - proposed_state_parts, - proposed_grads_target_log_prob)) - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(proposed_momentums, - step_sizes, - proposed_grads_target_log_prob)] - return [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] - - -def _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims, - name=None): - """Helper to `kernel` which computes the energy change.""" - with ops.name_scope( - name, "compute_energy_change", - ([current_target_log_prob, proposed_target_log_prob, - independent_chain_ndims] + - current_momentums + proposed_momentums)): - # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy - # since they're a mouthful and lets us inline more. - lk0, lk1 = [], [] - for current_momentum, proposed_momentum in zip(current_momentums, - proposed_momentums): - axis = math_ops.range(independent_chain_ndims, - array_ops.rank(current_momentum)) - lk0.append(_log_sum_sq(current_momentum, axis)) - lk1.append(_log_sum_sq(proposed_momentum, axis)) - - lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), - axis=-1) - lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), - axis=-1) - lp0 = -current_target_log_prob # potential - lp1 = -proposed_target_log_prob # proposed_potential - x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], - axis=-1) - - # The sum is NaN if any element is NaN or we see both +Inf and -Inf. - # Thus we will replace such rows with infinite energy change which implies - # rejection. Recall that float-comparisons with NaN are always False. - is_sum_determinate = ( - math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & - math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) - is_sum_determinate = array_ops.tile( - is_sum_determinate[..., array_ops.newaxis], - multiples=array_ops.concat([ - array_ops.ones(array_ops.rank(is_sum_determinate), - dtype=dtypes.int32), - [4], - ], axis=0)) - x = array_ops.where(is_sum_determinate, - x, - array_ops.fill(array_ops.shape(x), - value=x.dtype.as_numpy_dtype(np.inf))) - - return math_ops.reduce_sum(x, axis=-1) - - -def _choose(is_accepted, - accepted, - rejected, - independent_chain_ndims, - name=None): - """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" - def _expand_is_accepted_like(x): - with ops.name_scope("_choose"): - expand_shape = array_ops.concat([ - array_ops.shape(is_accepted), - array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], - dtype=dtypes.int32), - ], axis=0) - multiples = array_ops.concat([ - array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), - array_ops.shape(x)[independent_chain_ndims:], - ], axis=0) - m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), - multiples) - m.set_shape(x.shape) - return m - with ops.name_scope(name, "_choose", values=[ - is_accepted, accepted, rejected, independent_chain_ndims]): - return array_ops.where(_expand_is_accepted_like(accepted), - accepted, - rejected) - - -def _maybe_call_fn_and_grads(fn, - fn_arg_list, - fn_result=None, - grads_fn_result=None, - description="target_log_prob"): - """Helper which computes `fn_result` and `grads` if needed.""" - fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) - else [fn_arg_list]) - if fn_result is None: - fn_result = fn(*fn_arg_list) - if not fn_result.dtype.is_floating: - raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( - description)) - if grads_fn_result is None: - grads_fn_result = gradients_ops.gradients( - fn_result, fn_arg_list) - if len(fn_arg_list) != len(grads_fn_result): - raise ValueError("`{}` must be in one-to-one correspondence with " - "`grads_{}`".format(*[description]*2)) - if any(g is None for g in grads_fn_result): - raise ValueError("Encountered `None` gradient.") - return fn_result, grads_fn_result - - -def _prepare_args(target_log_prob_fn, state, step_size, - target_log_prob=None, grads_target_log_prob=None, - maybe_expand=False, description="target_log_prob"): - """Helper which processes input args to meet list-like assumptions.""" - state_parts = list(state) if _is_list_like(state) else [state] - state_parts = [ops.convert_to_tensor(s, name="state") - for s in state_parts] - target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( - target_log_prob_fn, - state_parts, - target_log_prob, - grads_target_log_prob, - description) - step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] - step_sizes = [ - ops.convert_to_tensor( - s, name="step_size", dtype=target_log_prob.dtype) - for s in step_sizes] - if len(step_sizes) == 1: - step_sizes *= len(state_parts) - if len(state_parts) != len(step_sizes): - raise ValueError("There should be exactly one `step_size` or it should " - "have same length as `current_state`.") - maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] - return [ - maybe_flatten(state_parts), - maybe_flatten(step_sizes), - target_log_prob, - grads_target_log_prob, - ] - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _log_sum_sq(x, axis=None): - """Computes log(sum(x**2)).""" - return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py deleted file mode 100644 index 05aa134ed5c11092316af5f3e45ba07fdb491e90..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Metropolis-Hastings and proposal distributions. - -@@kernel -@@evolve -@@proposal_uniform -@@proposal_normal -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops - -__all__ = [ - "kernel", - "evolve", - "proposal_uniform", - "proposal_normal", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "log_accept_ratio", - "current_target_log_prob", # "Current result" means "accepted". - "is_accepted", - "proposed_state", - ]) - - -def kernel(target_log_prob_fn, - proposal_fn, - current_state, - seed=None, - current_target_log_prob=None, - name=None): - """Runs the Metropolis-Hastings transition kernel. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - proposal_fn: Python callable which takes an argument like `current_state` - (or `*current_state` if it's a list) and returns a tuple of proposed - states of same shape as `state`, and a log ratio `Tensor` of same shape - as `current_target_log_prob`. The log ratio is the log-probability of - `state` given proposed states minus the log-probability of proposed - states given `state`. If the proposal is symmetric, set the second value - to `None`: this enables more efficient computation than explicitly - supplying a tensor of zeros. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: A name of the operation (optional). - - Returns: - next_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - #### Examples - - We illustrate Metropolis-Hastings on a Normal likelihood with - unknown mean. - - ```python - tfd = tf.contrib.distributions - tfp = tf.contrib.bayesflow - - loc = tf.get_variable("loc", initializer=1.) - x = tf.constant([0.0] * 50) - - def make_target_log_prob_fn(x): - def target_log_prob_fn(loc): - prior = tfd.Normal(loc=0., scale=1.) - likelihood = tfd.Independent( - tfd.Normal(loc=loc, scale=0.1), - reinterpreted_batch_ndims=1) - return prior.log_prob(loc) + likelihood.log_prob(x) - return target_log_prob_fn - - next_state, kernel_results = tfp.metropolis_hastings.kernel( - target_log_prob_fn=make_target_log_prob_fn(x), - proposal_fn=tfp.metropolis_hastings.proposal_normal(), - current_state=loc) - loc_update = loc.assign(next_state) - ``` - - We illustrate Metropolis-Hastings on a Normal likelihood with - unknown mean and variance. We apply 4 chains. - - ```python - tfd = tf.contrib.distributions - tfp = tf.contrib.bayesflow - - num_chains = 4 - loc = tf.get_variable("loc", shape=[num_chains], - initializer=tf.random_normal_initializer()) - scale = tf.get_variable("scale", shape=[num_chains], - initializer=tf.ones_initializer()) - x = tf.constant([0.0] * 50) - - def make_target_log_prob_fn(x): - data = tf.reshape(x, shape=[-1, 1]) - def target_log_prob_fn(loc, scale): - prior_loc = tfd.Normal(loc=0., scale=1.) - prior_scale = tfd.InverseGamma(concentration=1., rate=1.) - likelihood = tfd.Independent( - tfd.Normal(loc=loc, scale=scale), - reinterpreted_batch_ndims=1) - return (prior_loc.log_prob(loc) + - prior_scale.log_prob(scale) + - likelihood.log_prob(data)) - return target_log_prob_fn - - def proposal_fn(loc, scale): - loc_proposal = tfp.metropolis_hastings.proposal_normal() - scale_proposal = tfp.metropolis_hastings.proposal_uniform(minval=-1.) - proposed_loc, _ = loc_proposal(loc) - proposed_scale, _ = scale_proposal(scale) - proposed_scale = tf.maximum(proposed_scale, 0.01) - return [proposed_loc, proposed_scale], None - - next_state, kernel_results = tfp.metropolis_hastings.kernel( - target_log_prob_fn=make_target_log_prob_fn(x), - proposal_fn=proposal_fn, - current_state=[loc, scale]) - train_op = tf.group(loc.assign(next_state[0]), - scale.assign(next_state[1])) - ``` - - """ - with ops.name_scope( - name, "metropolis_hastings_kernel", - [current_state, seed, current_target_log_prob]): - with ops.name_scope("initialize"): - maybe_expand = lambda x: list(x) if _is_list_like(x) else [x] - current_state_parts = maybe_expand(current_state) - if current_target_log_prob is None: - current_target_log_prob = target_log_prob_fn(*current_state_parts) - - proposed_state, log_transit_ratio = proposal_fn(*current_state_parts) - proposed_state_parts = maybe_expand(proposed_state) - - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - - with ops.name_scope( - "accept_reject", - [current_state_parts, proposed_state_parts, - current_target_log_prob, proposed_target_log_prob]): - log_accept_ratio = proposed_target_log_prob - current_target_log_prob - if log_transit_ratio is not None: - # If the log_transit_ratio is None, then assume the proposal is - # symmetric, i.e., - # log p(old | new) - log p(new | old) = 0. - log_accept_ratio += log_transit_ratio - - # u < exp(log_accept_ratio), where u~Uniform[0,1) - # ==> log(u) < log_accept_ratio - random_value = random_ops.random_uniform( - array_ops.shape(log_accept_ratio), - dtype=log_accept_ratio.dtype, - seed=seed) - random_negative = math_ops.log(random_value) - is_accepted = random_negative < log_accept_ratio - next_state_parts = [array_ops.where(is_accepted, - proposed_state_part, - current_state_part) - for proposed_state_part, current_state_part in - zip(proposed_state_parts, current_state_parts)] - accepted_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(next_state_parts), - KernelResults( - log_accept_ratio=log_accept_ratio, - current_target_log_prob=accepted_log_prob, - is_accepted=is_accepted, - proposed_state=maybe_flatten(proposed_state_parts), - ), - ] - - -def evolve(initial_sample, - initial_log_density, - initial_log_accept_ratio, - target_log_prob_fn, - proposal_fn, - n_steps=1, - seed=None, - name=None): - """Performs `n_steps` of the Metropolis-Hastings update. - - Given a probability density function, `f(x)` and a proposal scheme which - generates new points from old, this `Op` returns a tensor - which may be used to generate approximate samples from the target distribution - using the Metropolis-Hastings algorithm. These samples are from a Markov chain - whose equilibrium distribution matches the target distribution. - - The probability distribution may have an unknown normalization constan. - We parameterize the probability density as follows: - - ```none - f(x) = exp(L(x) + constant) - ``` - - Here `L(x)` is any continuous function with an (possibly unknown but finite) - upper bound, i.e. there exists a number beta such that - `L(x)< beta < infinity` for all x. The constant is the normalization needed - to make `f(x)` a probability density (as opposed to just a finite measure). - - Although `initial_sample` can be arbitrary, a poor choice may result in a - slow-to-mix chain. In many cases the best choice is the one that maximizes - the target density, i.e., choose `initial_sample` such that - `f(initial_sample) >= f(x)` for all `x`. - - - If the support of the distribution is a strict subset of R^n (but of non zero - measure), then the unnormalized log-density `L(x)` should return `-infinity` - outside the support domain. This effectively forces the sampler to only - explore points in the regions of finite support. - - Usage: - This function is meant to be wrapped up with some of the common proposal - schemes (e.g. random walk, Langevin diffusion etc) to produce a more user - friendly interface. However, it may also be used to create bespoke samplers. - - The following example, demonstrates the use to generate a 1000 uniform random - walk Metropolis samplers run in parallel for the normal target distribution. - - ```python - n = 3 # dimension of the problem - - # Generate 1000 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = tf.get_variable( - "state", - initializer=tf.random_normal([1000, n], - mean=3.0, - dtype=tf.float64, - seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return -tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = tf.get_variable( - "state_log_density", - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = tf.get_variable( - "log_acceptance_ratio", - initializer=tf.zeros([1000], dtype=tf.float64)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = tf.initialize_all_variables() - with tf.Session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps and print out the mean across - # the chains every 100 iterations. - for n_iter in range(10): - # Executing the stepper advances the chain to the next state. - sess.run(stepper) - # Print out the current value of the mean(sample) for every dimension. - print(np.mean(sess.run(state), 0)) - # Estimated covariance matrix - samples = sess.run(state) - print(np.cov(samples, rowvar=False)) - ``` - - Args: - initial_sample: A float-like `tf.Variable` of any shape that can - be consumed by the `target_log_prob_fn` and `proposal_fn` - callables. - initial_log_density: Float-like `tf.Variable` with `dtype` and shape - equivalent to `target_log_prob_fn(initial_sample)`, i.e., matching - the result of `target_log_prob_fn` invoked at `current_state`. - initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching - `initial_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio after propagating the chain for `n_steps`. - target_log_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `target_log_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, here `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `target_log_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair should be a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric, i.e. - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to None instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - n_steps: A positive `int` or a scalar `int32` tensor. Sets the number of - iterations of the chain. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - forward_step: an `Op` to step the Markov chain forward for `n_steps`. - """ - - with ops.name_scope(name, "metropolis_hastings", [initial_sample]): - current_state = initial_sample - current_target_log_prob = initial_log_density - log_accept_ratio = initial_log_accept_ratio - - def step(i, current_state, current_target_log_prob, log_accept_ratio): - """Wrap single Markov chain iteration in `while_loop`.""" - next_state, kernel_results = kernel( - target_log_prob_fn=target_log_prob_fn, - proposal_fn=proposal_fn, - current_state=current_state, - current_target_log_prob=current_target_log_prob, - seed=seed) - accepted_log_prob = kernel_results.current_target_log_prob - log_accept_ratio = kernel_results.log_accept_ratio - return i + 1, next_state, accepted_log_prob, log_accept_ratio - - (_, accepted_state, accepted_target_log_prob, accepted_log_accept_ratio) = ( - control_flow_ops.while_loop( - cond=lambda i, *ignored_args: i < n_steps, - body=step, - loop_vars=[ - 0, # i - current_state, - current_target_log_prob, - log_accept_ratio, - ], - parallel_iterations=1 if seed is not None else 10, - # TODO(b/73775595): Confirm optimal setting of swap_memory. - swap_memory=1)) - - forward_step = control_flow_ops.group( - state_ops.assign(current_target_log_prob, accepted_target_log_prob), - state_ops.assign(current_state, accepted_state), - state_ops.assign(log_accept_ratio, accepted_log_accept_ratio)) - - return forward_step - - -def proposal_uniform(step_size=1., - seed=None, - name=None): - """Returns a callable that adds a random uniform tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). It adds a - sample from a random uniform distribution drawn from [-stepsize, stepsize] - to its input. It also returns the log of the ratio of the probability of - moving from the input point to the proposed point, but since this log ratio is - identically equal to 0 (because the probability of drawing a value `x` from - the symmetric uniform distribution is the same as the probability of drawing - `-x`), it simply returns None for the second element of the returned tuple. - - Args: - step_size: A positive `float` or a scalar tensor of real dtype - controlling the scale of the uniform distribution. - If step_size = a, then draws are made uniformly from [-a, a]. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, "proposal_uniform", [step_size]): - step_size = ops.convert_to_tensor(step_size, name="step_size") - - def proposal_fn(input_state, name=None): - """Adds a uniform perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensor` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - with ops.name_scope(name, "proposer", [input_state]): - input_state = ops.convert_to_tensor(input_state, name="input_state") - return input_state + random_ops.random_uniform( - array_ops.shape(input_state), - minval=-step_size, - maxval=step_size, - seed=seed), None - return proposal_fn - - -def proposal_normal(scale=1., - seed=None, - name=None): - """Returns a callable that adds a random normal tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). The callable - adds a sample from a normal distribution with the supplied standard deviation - and zero mean to its input argument (called the proposal point). - The callable returns a tuple with the proposal point as the first element. - The second element is identically `None`. It is included so the callable is - compatible with the expected signature of the proposal scheme argument in the - `metropolis_hastings` function. A value of `None` indicates that the - probability of going from the input point to the proposal point is equal to - the probability of going from the proposal point to the input point. - - Args: - scale: A positive `float` or a scalar tensor of any real dtype controlling - the scale of the normal distribution. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, "proposal_normal", [scale]): - scale = ops.convert_to_tensor(scale, name="scale") - - def proposal_fn(input_state, name=None): - """Adds a normal perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensor` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - - with ops.name_scope(name, "proposer", [input_state]): - input_state = ops.convert_to_tensor(input_state, name="input_state") - return input_state + random_ops.random_normal( - array_ops.shape(input_state), - mean=0., - stddev=scale, - dtype=scale.dtype, - seed=seed), None - return proposal_fn - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 985177e897f443989e466d1a498c461a30aeb5cb..d193a8459d00b83580509c8de25d5f7801b195fe 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -44,14 +44,14 @@ def expectation_importance_sampler(f, n=None, seed=None, name='expectation_importance_sampler'): - r"""Monte Carlo estimate of `E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]`. + r"""Monte Carlo estimate of `\\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\)`. - With `p(z) := exp{log_p(z)}`, this `Op` returns + With `\\(p(z) := exp^{log_p(z)}\\)`, this `Op` returns ``` - n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q, - \approx E_q[ f(Z) p(Z) / q(Z) ] - = E_p[f(Z)] + \\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\) + \\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\) + \\(= E_p[f(Z)]\\) ``` This integral is done in log-space with max-subtraction to better handle the @@ -95,9 +95,9 @@ def expectation_importance_sampler(f, log_values = log_f_z + log_p_z - q_log_prob_z return _logspace_mean(log_values) - # With f_plus(z) = max(0, f(z)), f_minus(z) = max(0, -f(z)), - # E_p[f(Z)] = E_p[f_plus(Z)] - E_p[f_minus(Z)] - # = E_p[f_plus(Z) + 1] - E_p[f_minus(Z) + 1] + # With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\), + # \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\) + # \\( = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\) # Without incurring bias, 1 is added to each to prevent zeros in logspace. # The logarithm is approximately linear around 1 + epsilon, so this is good # for small values of 'z' as well. @@ -121,13 +121,13 @@ def expectation_importance_sampler_logspace( name='expectation_importance_sampler_logspace'): r"""Importance sampling with a positive function, in log-space. - With `p(z) := exp{log_p(z)}`, and `f(z) = exp{log_f(z)}`, this `Op` - returns + With `\\(p(z) := exp^{log_p(z)}\\)`, and `\\(f(z) = exp{log_f(z)}\\)`, + this `Op` returns ``` - Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q, - \approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ] - = Log[E_p[f(Z)]] + \\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\) + \\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\) + \\(= Log[E_p[f(Z)]]\\) ``` This integral is done in log-space with max-subtraction to better handle the @@ -196,12 +196,12 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of `E_p[f(X)]`. + """Computes the Monte-Carlo approximation of `\\(E_p[f(X)]\\)`. This function computes the Monte-Carlo approximation of an expectation, i.e., ```none - E_p[f(X)] approx= m**-1 sum_i^m f(x_j), x_j ~iid p(X) + \\(E_p[f(X)] \approx= m^{-1} sum_i^m f(x_j), x_j\ ~iid\ p(X)\\) ``` where: @@ -216,8 +216,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, parameterless distribution (e.g., `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and expectation, i.e., - `grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }` where - `S_n = Avg{s_i}` and `s_i = f(x_i), x_i ~ p`. + `grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n }` where + `S_n = Avg{\\(s_i\\)}` and `\\(s_i = f(x_i), x_i ~ p\\)`. However, if p is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of non-reparameterized distributions. @@ -296,7 +296,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Args: f: Python callable which can return `f(samples)`. samples: `Tensor` of samples used to form the Monte-Carlo approximation of - `E_p[f(X)]`. A batch of samples should be indexed by `axis` dimensions. + `\\(E_p[f(X)]\\)`. A batch of samples should be indexed by `axis` + dimensions. log_prob: Python callable which can return `log_prob(samples)`. Must correspond to the natural-logarithm of the pdf/pmf of each sample. Only required/used if `use_reparametrization=False`. @@ -316,7 +317,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Returns: approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation - of `E_p[f(X)]`. + of `\\(E_p[f(X)]\\)`. Raises: ValueError: if `f` is not a Python `callable`. diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 6fdcd0f996ee011842a5add79f06264a28a2145c..8eac1243ef63dd09c5c5dad4bcd9bd7a15f58900 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -14,15 +14,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = ["**/OWNERS"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - package_group(name = "friends") cc_library( @@ -128,7 +119,7 @@ py_library( py_test( name = "gbdt_batch_test", - size = "small", + size = "medium", srcs = ["python/training/functions/gbdt_batch_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 289f5bb3140974d8c37f4938ceef27275b099f9a..17e20c4b315bab8852c90788567a2f2f92119f40 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -10,23 +10,17 @@ package( load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", + deps = [ + "custom_export_strategy", + ":custom_loss_head", + ":estimator", + ":model", + ":trainer_hooks", + ], ) py_library( @@ -149,7 +143,7 @@ py_library( py_test( name = "dnn_tree_combined_estimator_test", - size = "small", + size = "medium", srcs = ["dnn_tree_combined_estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 23ba76210b3b68d0d0b2eef9d4040882654bdad9..d9b0d89a03dce40d34f76bb1262d26bb587a2dc7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -54,7 +54,7 @@ def make_custom_export_strategy(name, An `ExportStrategy`. """ base_strategy = saved_model_export_utils.make_export_strategy( - serving_input_fn=export_input_fn) + serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index cec3892b57655dc967b4e7926f7f5a6a30084487..2e7b8cba05b89feaac3f47e13d26e7ae37a7b0ae 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -25,15 +25,20 @@ from __future__ import division from __future__ import print_function import six - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -46,6 +51,52 @@ from tensorflow.python.training import training_util _DNN_LEARNING_RATE = 0.001 +_CORE_MODE_TO_CONTRIB_MODE_ = { + model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN, + model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER +} + + +def _core_mode_to_contrib_mode(mode): + return _CORE_MODE_TO_CONTRIB_MODE_[mode] + + +def _export_outputs_to_output_alternatives(export_outputs): + """Converts EstimatorSpec.export_outputs to output_alternatives. + + Args: + export_outputs: export_outputs created by create_estimator_spec. + Returns: + converted output_alternatives. + """ + output = dict() + if export_outputs is not None: + for key, value in export_outputs.items(): + if isinstance(value, export_output.ClassificationOutput): + exported_predictions = { + prediction_key.PredictionKey.SCORES: value.scores, + prediction_key.PredictionKey.CLASSES: value.classes + } + output[key] = (constants.ProblemType.CLASSIFICATION, + exported_predictions) + return output + return None + + +def _estimator_spec_to_model_fn_ops(estimator_spec, is_regression): + alternatives = [] + if not is_regression: + _export_outputs_to_output_alternatives(estimator_spec.export_outputs) + + return model_fn.ModelFnOps( + mode=_core_mode_to_contrib_mode(estimator_spec.mode), + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + train_op=estimator_spec.train_op, + eval_metric_ops=estimator_spec.eval_metric_ops, + output_alternatives=alternatives) + def _get_optimizer(optimizer): if callable(optimizer): @@ -59,16 +110,26 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, - dnn_feature_columns, tree_learner_config, num_trees, - tree_examples_per_layer, - config=None, dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, - tree_feature_columns=None, - tree_center_bias=True): +def _dnn_tree_combined_model_fn(features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=False, + use_core_versions=False, + is_regression=False): """DNN and GBDT combined model_fn. Args: @@ -106,6 +167,9 @@ def _dnn_tree_combined_model_fn( set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + is_regression: Whether the problem is regression or not. Returns: A `ModelFnOps` object. @@ -135,11 +199,17 @@ def _dnn_tree_combined_model_fn( "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=dnn_partitioner) as input_layer_scope: - input_layer = layers.input_from_feature_columns( - columns_to_tensors=features, - feature_columns=dnn_feature_columns, - weight_collections=[dnn_parent_scope], - scope=input_layer_scope) + if use_core_versions: + input_layer = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope]) + else: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) previous_layer = input_layer for layer_id, num_hidden_units in enumerate(dnn_hidden_units): with variable_scope.variable_scope( @@ -222,24 +292,51 @@ def _dnn_tree_combined_model_fn( del loss return control_flow_ops.no_op() - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = _estimator_spec_to_model_fn_ops(dnn_train_op, + is_regression).train_op + + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = _estimator_spec_to_model_fn_ops(tree_train_op, + is_regression).train_op + + model_fn_ops = _estimator_spec_to_model_fn_ops(model_fn_ops, is_regression) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op if tree_center_bias: num_trees += 1 @@ -277,7 +374,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -322,6 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -336,8 +436,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -366,7 +466,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -411,6 +512,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.regression_head( label_name=label_name, @@ -426,11 +529,26 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias) + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config, + dnn_optimizer, + dnn_activation_fn, + dnn_dropout, + dnn_input_layer_partitioner, + dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, + tree_center_bias, + use_core_versions, + is_regression=True) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -460,7 +578,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -500,6 +619,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( @@ -507,8 +628,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 83d58c561008e8a5a69eb503d1605bb9e940f281..f495edc62f0909880c170ccb4cf5d11e3f20f55c 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -19,15 +19,17 @@ from __future__ import division from __future__ import print_function import tempfile - from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest @@ -100,6 +102,35 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateDontThrowExceptionWithCore(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + classifier = estimator.DNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + # Use core feature columns + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[], + use_core_versions=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 01752416b347dd0a5e646283b6b5572592df4690..70454aa6dbdb19297028a3f80822719bef5a0f72 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -81,7 +81,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=False, - loss_fn=loss_fn) + loss_fn=loss_fn, + label_keys=label_keys) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0f4c2298f56be48bb32f52d5d44cff8afe284f1e..0b28f81e7ca9a1228adc5bde19c429265e0aa9b8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -253,7 +253,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { private: float epsilon_; int32 num_quantiles_; - // An upperbound on the number of enteries that the summaries might have + // An upper bound on the number of entries that the summaries might have // for a feature. int64 max_elements_; bool generate_quantiles_; diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 131bd48562a55a08981ac73277e93024db0d85d3..3028c2281705bd7e34b212332160d25386559d4e 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -15,17 +15,6 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # Utils cc_library( diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index cf4f9a097a3368465fd4d9afb981bbaa68b4df49..35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -54,7 +54,7 @@ Status BatchFeatures::Initialize( TF_CHECK_AND_RETURN_IF_ERROR( dense_float_feature.dim_size(1) == 1, errors::InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = ", + "Dense float features may not be multivalent: dim_size(1) = ", dense_float_feature.dim_size(1))); dense_float_feature_columns_.emplace_back(dense_float_feature); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index 7815fa049aa165a944c45872c762b7a5bf91b316..a3b1b013e3a40116f74d6ed2df78d87ed3a11ac7 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -48,9 +48,9 @@ class BatchFeatures { Status GetFeatureColumnSizes(int64* const num_dense_float_features, int64* const num_sparse_float_features, int64* const num_sparse_int_features) const { - QCHECK_NE(num_dense_float_features, (int64*) nullptr); - QCHECK_NE(num_sparse_float_features, (int64*) nullptr); - QCHECK_NE(num_sparse_int_features, (int64*) nullptr); + QCHECK_NE(num_dense_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_int_features, static_cast(nullptr)); *num_dense_float_features = dense_float_feature_columns_.size(); *num_sparse_float_features = sparse_float_feature_columns_.size(); *num_sparse_int_features = sparse_int_feature_columns_.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 609519e8b1153a27d987c5f9ca9bfcc9ee6717d6..cfe9101e7435cd798569f3e52a87fc8ed7b6a239 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -59,7 +59,7 @@ TEST_F(BatchFeaturesTest, DenseFloatFeatures_Multivalent) { BatchFeatures batch_features(1); auto dense_vec = AsTensor({3.0f, 7.0f}, {1, 2}); auto expected_error = InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = 2"); + "Dense float features may not be multivalent: dim_size(1) = 2"); EXPECT_EQ(expected_error, batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {})); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index db34db998a7442c69f2ab468f4557d991429f4ee..ce67db797ded54f5023eaa89369d4781aad31a7c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -54,7 +54,7 @@ Status DropoutUtils::DropOutTrees( if (probability_of_skipping_dropout < 0 || probability_of_skipping_dropout > 1) { return errors::InvalidArgument( - "Probability of skiping dropout must be in [0,1] range"); + "Probability of skipping dropout must be in [0,1] range"); } const auto num_trees = weights.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index 928bfbfe5c9394ab4083aabced4c8e1149bb10aa..77c16da5410fe65b20839c7b6bc677067d7ff297 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -66,7 +66,7 @@ class DropoutUtils { // Current weights and num_updates will be updated as a result of this // func std::vector* current_weights, - // How many weight assignements have been done for each tree already. + // How many weight assignments have been done for each tree already. std::vector* num_updates); }; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 0138aae3dbd3773241cb6644db625b99f9bf1372..cc7604745e6bb90837eeca1123faa88dc914e4fc 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,7 +34,7 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multivalent. // This can be visualized like the following: // Instance | Sparse | // 0 | x | diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index 9a61e163eb5ff51dc75de4e40e0f43b090d03c0c..b07f0a4314246eea63764bb6d5e166dd720644fb 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -4,17 +4,6 @@ exports_files(["LICENSE"]) load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "learner_proto", srcs = [ diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 4407c4d981785a279b6296f4726a221cacb4c5b1..81411aa84ae848cfaa1392e82a1e38c3df19cdb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,7 +53,7 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the dimensiong + // If feature column is multivalent, this holds the index of the dimension // for the split. Defaults to 0. int32 dimension_id = 5; float threshold = 2; diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf..cf55759aaabfb265466f4bbf8b2806d4347ca0b1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -120,8 +120,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): """Sets up the prediction tests. Create a batch of two examples having one dense float, two sparse float - single valued, one sparse float multidimensionl and one sparse int features. - The data looks like the following: + single valued, one sparse float multidimensional and one sparse int + features. The data looks like the following: | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM | 0 | 7 | -3 | | 9,1 | __, 5.0 | 1 | -2 | | 4 | | 3, ___ @@ -810,7 +810,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # building. This tree should never be dropped. num_trees = 10 with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 10 trees with some weights. for i in range(0, num_trees): @@ -951,7 +951,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testDropOutZeroProb(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 1000 trees with some weights. for i in range(0, 999): @@ -994,7 +994,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testAveragingAllTrees(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() adjusted_tree_ensemble_config = ( tree_config_pb2.DecisionTreeEnsembleConfig()) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 81f58de28cbe98bb996c6665114eeb0030ee52f9..074623699d9d82f999c9cbc483ddcd8a959f4bad 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -482,7 +482,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): """Sets up the quantile op tests. Create a batch of 4 examples having 2 dense and 4 sparse features. - Forth sparse feature is multivalent (3 dimensional) + Fourth sparse feature is multivalent (3 dimensional) The data looks like this | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 97d57e8b23608d4c3a8719426a75056fc6417d1d..1b184d296b329cee481db67992e77d1e33e18035 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -184,7 +184,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Finalizes quantile summary stream and resets it for next iteration. Args: - stamp_token: Exepcted current token. + stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: A list of quantiles or approximate boundaries. diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD index 9fc101612f1e2a6bf6c5d86ea8c7199936dbb069..c0651868453d40d57e842862855f89e6845c507f 100644 --- a/tensorflow/contrib/boosted_trees/resources/BUILD +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -9,17 +9,6 @@ package( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "stamped_resource", hdrs = ["stamped_resource.h"], diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index fe8bd072afd43a64fa62a65bd8900b5a98dbe761..f3a75e8688ece19a6e6fd53ee9faf7f4144d76cf 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -14,18 +14,6 @@ load( "tf_py_test", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_gen_op_libs( op_lib_names = ["bigquery_reader_ops"], deps = [ diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 56f930a9a8d32c5c3a025163ef56c9562f17d864..ff46f0daa80a70badedf73e15bfaf4dca85fdd89 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -20,20 +20,6 @@ load( "tf_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_kernel_library( name = "bigquery_reader_ops", srcs = ["bigquery_reader_ops.cc"], @@ -73,6 +59,7 @@ tf_cc_test( ], deps = [ ":bigquery_table_accessor", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index e9b79a066def566096d6c3f3745974423e3371d1..7416eb19d3324fad84876cde5353bc25bac8f648 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/test.h" @@ -28,8 +29,8 @@ constexpr char kTestProject[] = "test-project"; constexpr char kTestDataset[] = "test-dataset"; constexpr char kTestTable[] = "test-table"; -bool HasSubstr(const string& base, const string& substr) { - bool ok = StringPiece(base).contains(substr); +bool HasSubstr(StringPiece base, StringPiece substr) { + bool ok = str_util::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 1a124eca364424b651de86bfaac6f33ad131804b..c239e6f8f960910cee14e1df7c4678c643496f54 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -10,19 +10,6 @@ package( licenses(["notice"]) # Apache 2.0 -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - py_library( name = "cluster_resolver_pip", srcs = [ diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 300b19733e2b4d1b912f966e94ae0286ed9c694d..95c5c920aa2ccf92d8aa6aa179102fe379f0236c 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -73,7 +73,7 @@ class TPUClusterResolver(ClusterResolver): zone=None, project=None, job_name='worker', - coordinator_name='coordinator', + coordinator_name=None, coordinator_address=None, credentials='default', service=None): diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 48c3f6bb4f2d1643982e03d9ed68db14c10c184a..e1e3e6867a24b917885a9ab7e780df55742ec0f9 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -117,7 +117,8 @@ class TPUClusterResolverTest(test.TestCase): zone=None, tpu=['test-tpu-1'], credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ @@ -170,6 +171,7 @@ class TPUClusterResolverTest(test.TestCase): project='test-project', zone='us-central1-c', tpu=['test-tpu-1'], + coordinator_name='coordinator', coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) @@ -196,6 +198,7 @@ class TPUClusterResolverTest(test.TestCase): project='test-project', zone='us-central1-c', tpu='test-tpu-1', + coordinator_name='coordinator', coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) @@ -239,7 +242,8 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( tpu='test-tpu-1', credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 17f65999faaf5c0ca39bfbc968a9140dbff49c2e..1fefb731a775d9cd2478cbb654662ec6ba673fed 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) +set(GRPC_TAG bd6bdf93279a39a8cd92978fd7c9d14eccd98fc2) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -35,6 +35,7 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index f3a37ff5088e3f9e54e38c0edb5777c27b26969f..b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) +set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index aaae18a313dd082b428654091c9411600c981ec9..6f059c7225dd0938b758e8f9c28ec36fcff6db4c 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -42,7 +42,6 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11") add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11") set (NSYNC_OS_CPP_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/c++11/src/per_thread_waiter.cc" "platform/c++11/src/yield.cc" "platform/c++11/src/time_rep_timespec.cc" @@ -52,6 +51,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") add_compile_options ("/TP") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/win32/src/clock_gettime.c" "platform/win32/src/pthread_key_win32.cc" ${NSYNC_OS_CPP_SRC} @@ -68,6 +68,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/posix/src/clock_gettime.c" "platform/posix/src/nsync_semaphore_mutex.c" ) @@ -75,9 +76,11 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") "platform/posix/src/start_thread.c" ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") + include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/linux/src/nsync_semaphore_futex.c" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -87,6 +90,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -96,6 +100,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -105,6 +110,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 0d2a6a23db26af2fb9498849aa93e74379915fe3..f273c7e5508e10407d013acd7adc08c732322841 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -79,9 +79,11 @@ tensorflow/python/keras/_impl/keras/preprocessing tensorflow/python/keras/_impl/keras/utils tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions tensorflow/python/kernel_tests/linalg tensorflow/python/kernel_tests/random +tensorflow/python/kernel_tests/testdata tensorflow/python/layers tensorflow/python/lib tensorflow/python/lib/core @@ -147,8 +149,6 @@ tensorflow/contrib/crf tensorflow/contrib/crf/python tensorflow/contrib/crf/python/ops tensorflow/contrib/cudnn_rnn -tensorflow/contrib/cudnn_rnn/kernels -tensorflow/contrib/cudnn_rnn/ops tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops @@ -160,6 +160,9 @@ tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto tensorflow/contrib/deprecated +tensorflow/contrib/distribute +tensorflow/contrib/distribute/python +tensorflow/contrib/distribute/python/examples tensorflow/contrib/distributions tensorflow/contrib/distributions/python tensorflow/contrib/distributions/python/ops @@ -332,6 +335,7 @@ tensorflow/contrib/nccl/kernels tensorflow/contrib/nccl/ops tensorflow/contrib/nccl/python tensorflow/contrib/nccl/python/ops +tensorflow/contrib/nearest_neighbor tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops tensorflow/contrib/nearest_neighbor/python @@ -342,6 +346,7 @@ tensorflow/contrib/nn/python/ops tensorflow/contrib/opt tensorflow/contrib/opt/python tensorflow/contrib/opt/python/training +tensorflow/contrib/optimizer_v2 tensorflow/contrib/pi_examples tensorflow/contrib/pi_examples/camera tensorflow/contrib/pi_examples/label_image diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index c03c0c80fe62a4f95d0fcf240ee25725a19d86f0..0c80d529af5230ed6d36b265e12ee4b749a14ec4 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,4 +1,5 @@ tensorflow/core +tensorflow/core/kernels/boosted_trees tensorflow/core/profiler tensorflow/python tensorflow/contrib/boosted_trees/proto diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 998f99ecc19f88921dce14fde892912fb699ad08..ed018b4fed8e47632f632723f19cc755f2079f86 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -67,8 +67,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 59e094812aaf4da2549d96314fc550e5635f9de8..092a48bc6b63503be39343a1f936875082490b3e 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -15,19 +15,21 @@ set(tf_op_lib_names "audio_ops" "array_ops" - "batch_ops" + "batch_ops" "bitwise_ops" + "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" "control_flow_ops" "ctc_ops" + "cudnn_rnn_ops" "data_flow_ops" "dataset_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" - "list_ops" + "list_ops" "lookup_ops" "logging_ops" "manip_ops" @@ -47,7 +49,7 @@ set(tf_op_lib_names "state_ops" "stateless_random_ops" "string_ops" - "summary_ops" + "summary_ops" "training_ops" ) @@ -84,7 +86,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b730ebd3baacafe8ae401e8987104f3062372954..fae45ead5cafcb0f55834af223555f6e65f16015 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -319,6 +319,7 @@ GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") +GENERATE_PYTHON_OP_LIB("boosted_trees_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") @@ -326,6 +327,7 @@ GENERATE_PYTHON_OP_LIB("checkpoint_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") +GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") GENERATE_PYTHON_OP_LIB("image_ops") @@ -348,6 +350,7 @@ GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") GENERATE_PYTHON_OP_LIB("string_ops") +GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) @@ -366,8 +369,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" @@ -419,8 +420,6 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) -GENERATE_PYTHON_OP_LIB("summary_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) @@ -475,6 +474,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 6d36d5fc5c2854b2d7d2542a3cb12e033e193b88..9738bbeb9aebaeb67495127528e26634887d392c 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -100,8 +100,7 @@ if(WIN32) endif(WIN32) target_include_directories(tensorflow PUBLIC - $ - $) + $) install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin @@ -133,10 +132,6 @@ install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ DESTINATION include/google FILES_MATCHING PATTERN "*.h") -# nsync headers -install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ - DESTINATION include/external/nsync - FILES_MATCHING PATTERN "*.h") # Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ DESTINATION include/Eigen) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index b3e5b30826097d6c747245fec975fcbea3785d15..92f2ab6dea8e7da5dd8481639eda24e31c06848f 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -195,9 +195,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" # Loading resources in contrib doesn't seem to work on Windows @@ -208,6 +210,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py" # Test is flaky on Windows GPU builds (b/38283730). "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" + # Disable following manual tag in BUILD. + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + ) if (WIN32) set(tf_test_src_py_exclude @@ -279,6 +284,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py" # Segfaults on Windows. # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index ec3d550b70d2aaa23b989c44f3d86fa87cffb335..ce12e38248785987e51befa47d04143e235554fe 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -154,14 +154,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 388d8e6ed6d9cb9400b0bfbe8e3f50b80149ea1a..bcee0b04c8430588c2dcbc199504bede0436f8f1 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -46,15 +46,3 @@ cuda_py_test( ], xla_enabled = True, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD index 8ec706df74e2c91345c4bf7a506fdb424a996773..fa44c4d54e1ee871feb425115525b1cf8b732214 100644 --- a/tensorflow/contrib/copy_graph/BUILD +++ b/tensorflow/contrib/copy_graph/BUILD @@ -41,15 +41,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index b806799202bff4f2f6dbf717fbeea74a04b8cd6e..102bc460fdadb0ad5dc9a2960b8655c55357108e 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -201,7 +201,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it #stores String-based info such as name, device and type of the op. #Unique to every Operation instance. - new_node_def = deepcopy(op._node_def) + new_node_def = deepcopy(op.node_def) #Change the name new_node_def.name = new_name @@ -211,7 +211,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #Make a copy of the op_def too. #Its unique to every _type_ of Operation. - op_def = deepcopy(op._op_def) + op_def = deepcopy(op.op_def) #Initialize a new Operation instance new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types, diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD index 7aad4abdb908d0284b85137bff842bd0f38d09c6..5c1a17df4f95f3c4d05b286de0e3d7b009a76bd7 100644 --- a/tensorflow/contrib/crf/BUILD +++ b/tensorflow/contrib/crf/BUILD @@ -40,15 +40,3 @@ cuda_py_tests( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fec358c4e1067dc8dc8173d1b9d05dc90b90ca05..8b5d13f72555516babc4250fd934c55adc3d1b8b 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -tf_custom_op_library( - name = "python/ops/_cudnn_rnn_ops.so", - srcs = [ - "kernels/cudnn_rnn_ops.cc", - "ops/cudnn_rnn_ops.cc", - ], - deps = [ - "//tensorflow/core/kernels:bounds_check_lib", - "@farmhash_archive//:farmhash", - ], -) - -tf_kernel_library( - name = "cudnn_rnn_kernels", - srcs = ["kernels/cudnn_rnn_ops.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - "//third_party/eigen3", - "@farmhash_archive//:farmhash", - ], -) - -tf_gen_op_libs( - op_lib_names = ["cudnn_rnn_ops"], - deps = [ - "//tensorflow/core:lib", - ], -) - -tf_gen_op_wrapper_py( - name = "cudnn_rnn_ops", - deps = [":cudnn_rnn_ops_op_lib"], -) tf_custom_op_py_library( name = "cudnn_rnn_py", @@ -64,20 +22,13 @@ tf_custom_op_py_library( "python/layers/cudnn_rnn.py", "python/ops/cudnn_rnn_ops.py", ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cudnn_rnn_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:cudnn_rnn_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", @@ -172,32 +123,3 @@ cuda_py_test( "requires_cudnn5", ], ) - -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":cudnn_rnn_ops_op_lib", - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index e87162f0ee9cc4eed795555171f55a93639e83cf..2ac94424061a07e5727a98642aa855222c0afb81 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -17,27 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -_cudnn_rnn_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) - CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" CUDNN_LSTM = "lstm" @@ -91,19 +86,23 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python - r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate - h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 + # reset gate + $$r_t = \sigma(x_t * W_r + h_t-1 * R_h + b_{Wr} + b_{Rr})$$ + # update gate + $$u_t = \sigma(x_t * W_u + h_t-1 * R_u + b_{Wu} + b_{Ru})$$ + # new memory gate + $$h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_{Rh}) + b_{Wh})$$ + $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): ```python - h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_Wh) # new memory gate + # new memory gate + \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) ``` which is not equivalent to Cudnn GRU: in addition to the extra bias term b_Rh, ```python - r .* (h * R) != (r .* h) * R + \\(r .* (h * R) != (r .* h) * R\\) ``` """ diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 0458199ff771bc45603106411550a39448e515b8..7bb0dc1c0f695f4d1c7739fa11764ded4ff9410a 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -8,6 +8,11 @@ load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", + "if_not_windows", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", ) py_library( @@ -17,6 +22,7 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", @@ -26,26 +32,21 @@ py_library( ], ) +cc_library( + name = "lib_proto_parsing_for_dataset_ops", + deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), +) + tf_custom_op_library( name = "_dataset_ops.so", srcs = ["ops/dataset_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + + if_static( + extra_deps = [":lib_proto_parsing_for_dataset_ops"], + otherwise = [], + ), ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) - -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 9212b69700941c190df1d44ed308147105c56fba..125260b4c1f6b63c8f83f28d1829afe2d9d3ea97 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -25,6 +25,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter @@SqlDataset +@@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length @@dense_to_sparse_batch @@ -32,10 +33,12 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@group_by_window @@ignore_errors @@make_batched_features_dataset +@@make_csv_dataset @@make_saveable_from_iterator @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@prefetch_to_device @@read_batch_features @@rejection_resample @@scan @@ -53,6 +56,7 @@ from __future__ import print_function # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops.batching import assert_element_shape from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import map_and_batch @@ -67,7 +71,9 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset +from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample @@ -80,3 +86,6 @@ from tensorflow.python.ops.parsing_ops import parse_single_example_v2 as parse_s from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) + +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index c87da7dfaa5943f7918c370f63362673844c7f0e..83ada6fb67dcbff595a38ce9e8609bdd1219b075 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -61,14 +61,3 @@ cc_library( "@protobuf_archive//:protobuf_headers", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 1baac3ea5239659e65881e5b2dea4fe1a8c49d1b..a2bfce03620a1482f5b21cbf23c66833bc5cd480 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -40,8 +40,7 @@ class FunctionBufferingResource : public ResourceBase { const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args, - int64 thread_pool_size) + const std::vector& func_args) : lib_(lib), pflr_(std::move(pflr)), func_(func), @@ -49,27 +48,13 @@ class FunctionBufferingResource : public ResourceBase { source_device_(source_device), target_device_(target_device), func_args_(func_args), - thread_pool_(new thread::ThreadPool(Env::Default(), ThreadOptions(), - "buffer_resource", thread_pool_size, - false /* low_latency_hint */)), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), - cancelled_(false) { - runner_ = [this](std::function c) { - thread_pool_->Schedule(std::move(c)); - }; - } + cancelled_(false) {} ~FunctionBufferingResource() override { Cancel(); - { - mutex_lock l(mu_); - while (is_buffering_) { - cond_var_.wait(l); - } - } - delete thread_pool_; } string DebugString() override { @@ -103,6 +88,20 @@ class FunctionBufferingResource : public ResourceBase { void Cancel() LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. + void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; } // If the buffer has anything, runs `callback` on the first element in the @@ -167,15 +166,12 @@ class FunctionBufferingResource : public ResourceBase { for (int i = 0; i < cancellation_callbacks.size(); ++i) { cancellation_callbacks[i](cancellation_buffer_elements[i]); } - // We only wait on cond_var_ in the destructor, so there would atmost be - // one waiter to notify. - cond_var_.notify_one(); + cond_var_.notify_all(); return; } FunctionLibraryRuntime::Options opts; // Copied from CapturedFunction::generate_step_id(); opts.step_id = -std::abs(static_cast(random::New64())); - opts.runner = &runner_; opts.source_device = source_device_; AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); @@ -194,13 +190,12 @@ class FunctionBufferingResource : public ResourceBase { mutex_lock l(mu_); BufferElement buffer_element; buffer_element.status = status; - if (!status.ok()) { + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { end_of_sequence_ = true; is_buffering_ = false; - buffer_.push_back(std::move(buffer_element)); - return; } - buffer_element.value.swap(*rets); buffer_.push_back(std::move(buffer_element)); if (!requests_.empty()) { buffer_front = std::move(buffer_.front()); @@ -208,9 +203,16 @@ class FunctionBufferingResource : public ResourceBase { callback = std::move(requests_.front()); requests_.pop_front(); } - if (buffer_.size() < buffer_size_) { + if (buffer_.size() < buffer_size_ && !end_of_sequence_) { restart_buffering = true; } else { + // When the buffer is full, we don't want to call + // FillBuffer() unless we're in cancellation phase in which + // case FillBuffer() will do the final cleanup post + // cancellation. + if (cancelled_) { + restart_buffering = true; + } is_buffering_ = false; } } @@ -231,11 +233,9 @@ class FunctionBufferingResource : public ResourceBase { const string source_device_; const string target_device_; const std::vector func_args_; - thread::ThreadPool* thread_pool_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); - std::function)> runner_ = nullptr; bool is_buffering_ GUARDED_BY(mu_); bool end_of_sequence_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_); @@ -250,7 +250,6 @@ class FunctionBufferResourceHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_)); } ~FunctionBufferResourceHandleOp() override { @@ -298,9 +297,10 @@ class FunctionBufferResourceHandleOp : public OpKernel { this](FunctionBufferingResource** ptr) { *ptr = new FunctionBufferingResource( clone_lib, std::move(pflr), func_, buffer_size_, - source_device, target_device, func_args, thread_pool_size_); + source_device, target_device, func_args); return Status::OK(); })); + core::ScopedUnref s(buffer); OP_REQUIRES_OK(ctx, buffer->Instantiate()); initialized_ = true; } @@ -319,7 +319,6 @@ class FunctionBufferResourceHandleOp : public OpKernel { int64 buffer_size_; string container_; string name_; - int64 thread_pool_size_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") @@ -360,25 +359,27 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, handle, &buffer), done); - core::ScopedUnref s(buffer); if (buffer->Finished()) { + buffer->Unref(); ctx->SetStatus(errors::OutOfRange("end_of_sequence")); done(); return; } FunctionBufferCallback callback = - [ctx, done](const BufferElement& buffer_element) { + [ctx, buffer, done](const BufferElement& buffer_element) { Status s = buffer_element.status; if (!s.ok()) { ctx->SetStatus(s); + buffer->Unref(); done(); return; } for (size_t i = 0; i < buffer_element.value.size(); ++i) { ctx->set_output(i, buffer_element.value[i]); } + buffer->Unref(); done(); }; buffer->MaybeGet(std::move(callback)); @@ -400,4 +401,62 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") FunctionBufferingResourceGetNextOp); #endif // TENSORFLOW_USE_SYCL +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. +class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. + Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. + device_name_t->scalar()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index a4c1212da11a2410461a120ed5f7116e80e4b903..cf0a8bbccb5813c799e7e6db91d73e2ecf4107f8 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -37,6 +37,14 @@ REGISTER_OP("UniqueDataset") Creates a dataset that contains the unique elements of `input_dataset`. )doc"); +REGISTER_OP("IteratorGetDevice") + .Input("resource: resource") + .Output("device: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Returns the name of the device on which `resource` has been placed. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") @@ -45,7 +53,6 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") - .Attr("thread_pool_size: int") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -55,7 +62,6 @@ target_device: Target device to execute the function on. resource: Handle to the resource created. f: Function to be executed. buffer_size: Size of the buffer. -thread_pool_size: Size of the threadpool doing the prefetching. container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name @@ -75,6 +81,15 @@ output: A list of return values. output_types: The type list for the return values. )doc"); +REGISTER_OP("FunctionBufferingResourceReset") + .Input("function_buffer_resource: resource") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Resets the FunctionBufferingResource. + +function_buffer_resource: The FunctionBufferingResource handle. +)doc"); + REGISTER_OP("ThreadPoolDataset") .Input("input_dataset: variant") .Input("thread_pool: resource") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2c4d4adfdad6d2b3268896cb91cd0357b2b814d9..7270d533c69002ad6b318645f1ef07ebb45a85c3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -22,6 +22,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", @@ -294,9 +295,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", - "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", - "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -479,10 +478,6 @@ py_test( size = "small", srcs = ["prefetching_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "manual", - "no_oss", # b/68785503 - ], deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", @@ -514,17 +509,3 @@ tf_py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 71dc1c1172c9d515d4c85f85257c952135098329..413d8737978b695ac443c92036d6641e5c73f28c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -28,8 +28,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -311,10 +313,10 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> - # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) @@ -381,11 +383,51 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testBatchAndMapDataset(self): - return self._testBatchAndMapDatasetHelper() + def testMapAndBatchDataset(self): + return self._testMapAndBatchDatasetHelper() - def testBatchAndMapDatasetWithParallelBatching(self): - return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchDatasetWithParallelBatching(self): + return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + iterator = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), + batch_size=4, + drop_remainder=drop_remainder)).make_one_shot_iterator()) + if drop_remainder: + self.assertEqual([4, 1], iterator.output_shapes.as_list()) + else: + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + if not drop_remainder: + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchPartialBatch(self): + return self._testMapAndBatchPartialBatchHelper() + + def testMapAndBatchPartialBatchDropRemainder(self): + return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) + + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) def testMapAndBatchSparse(self): @@ -411,7 +453,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testBatchAndMapDatasetFails(self): + def testMapAndBatchDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -425,7 +467,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testBatchAndMapDatasetShapeMismatch(self): + def testMapAndBatchDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -539,5 +581,73 @@ class PaddedBatchDatasetSerializationTest( lambda: build_dataset(seq_lens2), 8) +class RestructuredDatasetTest(test.TestCase): + + def test_assert_element_shape(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 94f800e8a58bc34eef3034cd976b931528c01940..6002cc73c8b41c2f20beaf0158af813807e58c90 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -104,6 +104,21 @@ class GroupByWindowTest(test.TestCase): self.assertAllEqual([0, 0, 0], sess.run(get_next)) self.assertAllEqual([1], sess.run(get_next)) + def testEmpty(self): + iterator = ( + dataset_ops.Dataset.range(4).apply( + grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Window size must be greater than zero, but got 0."): + print(sess.run(get_next)) + def testReduceFuncError(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -468,6 +483,31 @@ class BucketBySequenceLength(test.TestCase): self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual(sorted(boundaries), sorted(lengths_val)) + def testTupleElements(self): + + def elements_gen(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (x, y) + + def element_length_fn(x, y): + del y + return array_ops.shape(x)[0] + + dataset = dataset_ops.Dataset.from_generator( + generator=elements_gen, + output_shapes=(tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([])), + output_types=(dtypes.int32, dtypes.int32)) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8])) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index dc3e38db59301bf1819999f479171af35930e9d2..4b5026067007e7ef0051f1647da1151be3a5631c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import threading from tensorflow.contrib.data.python.ops import prefetching_ops @@ -26,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -33,30 +33,34 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class StagingAreaOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test.TestCase): def setUp(self): self._event = threading.Event() - def _prefetch_fn_helper(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + def _create_ds_and_iterator(self, device0, initializable=False): def gen(): - for i in itertools.count(start=1, step=1): - yield [i + 0.0] + for i in range(1, 10): + yield [float(i)] if i == 6: self._event.set() with ops.device(device0): - dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() + ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + if initializable: + ds_iterator = ds.make_initializable_iterator() + else: + ds_iterator = ds.make_one_shot_iterator() + return (ds, ds_iterator) + + def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): + ds_iterator_handle = ds_iterator.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) + h, ds.output_types, ds.output_shapes) return remote_iterator.get_next() target = constant_op.constant(device0) @@ -64,15 +68,28 @@ class StagingAreaOpsTest(test.TestCase): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, target_device=target, - string_arg=iterator_3_handle, + string_arg=ds_iterator_handle, buffer_size=3, - thread_pool_size=2, shared_name=buffer_name) with ops.device(device1): prefetch_op = prefetching_ops.function_buffering_resource_get_next( function_buffer_resource=buffer_resource_handle, output_types=[dtypes.float32]) + reset_op = prefetching_ops.function_buffering_resource_reset( + function_buffer_resource=buffer_resource_handle) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + return (prefetch_op, reset_op, destroy_op) + + def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) + prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, + device0, device1) with self.test_session(config=worker_config) as sess: elem = sess.run(prefetch_op) @@ -86,26 +103,240 @@ class StagingAreaOpsTest(test.TestCase): self._event.wait() elem = sess.run(prefetch_op) self.assertEqual(elem, [5.0]) - sess.run( - resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True)) + sess.run(destroy_op) def testSameDeviceCPU(self): - self._prefetch_fn_helper("same_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:0") + self._prefetch_fn_helper_one_shot("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") def testDifferentDeviceCPU(self): - self._prefetch_fn_helper("diff_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:1") + self._prefetch_fn_helper_one_shot("diff_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:1") def testDifferentDeviceCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") - self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/gpu:0") + self._prefetch_fn_helper_one_shot("cpu_gpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/gpu:0") + + def testReinitialization(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + # Lets reset the function buffering resource and reinitialize the + # iterator. Should be able to go through this again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run(destroy_op) + + def testReinitializationOutOfRange(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + # Now reset everything and try it out again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + +class PrefetchToDeviceTest(test.TestCase): + + def testPrefetchToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_initializable_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 699e8e7865502facd05e0c4d6d4f01b80f7c050c..6ee1b572f121a9a40dfd638f7a858d5f1176ea3c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -35,9 +35,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -568,12 +566,20 @@ class MakeCsvDatasetTest(test.TestCase): dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string ] COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] + DEFAULT_VALS = [[], [], [], [], ["NULL"]] + DEFAULTS = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.float64), + constant_op.constant(["NULL"], dtype=dtypes.string) + ] LABEL = COLUMNS[0] def setUp(self): super(MakeCsvDatasetTest, self).setUp() self._num_files = 2 - self._num_records = 7 + self._num_records = 11 self._test_filenames = self._create_files() def _csv_values(self, fileno, recordno): @@ -588,49 +594,63 @@ class MakeCsvDatasetTest(test.TestCase): def _csv_record(self, fileno, recordno): return ",".join(str(v) for v in self._csv_values(fileno, recordno)) + def _create_file(self, fileno, header=True, comment=True): + fn = os.path.join(self.get_temp_dir(), "csv_file%d.csv" % fileno) + f = open(fn, "w") + if header: + f.write(",".join(self.COLUMNS) + "\n") + for recno in range(self._num_records): + f.write(self._csv_record(fileno, recno) + "\n") + if comment: + f.write("# Some comment goes here. Should be ignored!\n") + f.close() + return fn + def _create_files(self): filenames = [] for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "csv_file%d.csv" % i) - filenames.append(fn) - f = open(fn, "w") - f.write(",".join(self.COLUMNS) + "\n") # header line - for j in range(self._num_records): - f.write(self._csv_record(i, j) + "\n") - f.write("# Some comment goes here. Should be ignored!\n") - f.close() + filenames.append(self._create_file(i)) return filenames - def _make_csv_dataset(self, - filenames, - defaults, - label_key=LABEL, - batch_size=1, - num_epochs=1, - shuffle=False, - shuffle_seed=None): + def _make_csv_dataset( + self, + filenames, + defaults, + column_names=COLUMNS, + label_name=LABEL, + batch_size=1, + num_epochs=1, + shuffle=False, + shuffle_seed=None, + header=True, + comment="#", + na_value="", + default_float_type=dtypes.float32, + ): return readers.make_csv_dataset( filenames, - column_keys=self.COLUMNS, - column_defaults=defaults, - label_key=label_key, batch_size=batch_size, + column_names=column_names, + column_defaults=defaults, + label_name=label_name, num_epochs=num_epochs, shuffle=shuffle, shuffle_seed=shuffle_seed, - skip=1, - filter_fn= - lambda line: math_ops.not_equal(string_ops.substr(line, 0, 1), "#"), + header=header, + comment=comment, + na_value=na_value, + default_float_type=default_float_type, ) - def _next_actual_batch(self, file_indices, batch_size, num_epochs): + def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): features = {col: list() for col in self.COLUMNS} for _ in range(num_epochs): for i in file_indices: for j in range(self._num_records): values = self._csv_values(i, j) - if not values[-1]: - values[-1] = "NULL" # null values in csv are interpreted as default + for n, v in enumerate(values): + if v == "": # pylint: disable=g-explicit-bool-comparison + values[n] = defaults[n][0] values[-1] = values[-1].encode("utf-8") # Regroup lists by column instead of row @@ -651,7 +671,8 @@ class MakeCsvDatasetTest(test.TestCase): sess, dataset, file_indices, - label_key=LABEL, + defaults=tuple(DEFAULT_VALS), + label_name=LABEL, batch_size=1, num_epochs=1, ): @@ -659,11 +680,11 @@ class MakeCsvDatasetTest(test.TestCase): get_next = iterator.get_next() for expected_features in self._next_actual_batch(file_indices, batch_size, - num_epochs): + num_epochs, defaults): actual_features = sess.run(get_next) - if label_key is not None: - expected_labels = expected_features.pop(label_key) + if label_name is not None: + expected_labels = expected_features.pop(label_name) # Compare labels self.assertAllEqual(expected_labels, actual_features[1]) actual_features = actual_features[0] # Extract features dict from tuple @@ -676,10 +697,7 @@ class MakeCsvDatasetTest(test.TestCase): sess.run(get_next) def test_make_csv_dataset(self): - defaults = [ - constant_op.constant([], dtype=d) for d in self.COLUMN_TYPES[:-1] - ] - defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + defaults = self.DEFAULTS with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: @@ -705,11 +723,26 @@ class MakeCsvDatasetTest(test.TestCase): self._verify_records( sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + def test_make_csv_dataset_with_bad_columns(self): + """Tests that exception is raised when input is malformed. + """ + dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] + defaults = self.DEFAULTS + + # Duplicate column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, column_names=dupe_columns) + + # Label key not one of column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, label_name="not_a_real_label") + def test_make_csv_dataset_with_no_label(self): - defaults = [ - constant_op.constant([], dtype=d) for d in self.COLUMN_TYPES[:-1] - ] - defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + """Tests that CSV datasets can be created when no label is specified. + """ + defaults = self.DEFAULTS with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Read from both files. Make sure this works with no label key supplied. @@ -718,16 +751,64 @@ class MakeCsvDatasetTest(test.TestCase): defaults, batch_size=2, num_epochs=10, - label_key=None) + label_name=None) self._verify_records( sess, dataset, range(self._num_files), batch_size=2, num_epochs=10, - label_key=None) + label_name=None) + + def test_make_csv_dataset_with_no_comments(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), comment=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + comment=None, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def test_make_csv_dataset_with_no_header(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), header=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + header=False, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) def test_make_csv_dataset_with_types(self): + """Tests that defaults can be a dtype instead of a Tensor for required vals. + """ defaults = [d for d in self.COLUMN_TYPES[:-1]] defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) with ops.Graph().as_default() as g: @@ -735,10 +816,109 @@ class MakeCsvDatasetTest(test.TestCase): dataset = self._make_csv_dataset(self._test_filenames, defaults) self._verify_records(sess, dataset, range(self._num_files)) + def test_make_csv_dataset_with_no_col_names(self): + """Tests that datasets can be created when column names are not specified. + + In that case, we should infer the column names from the header lines. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + column_names=None, + batch_size=2, + num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def test_make_csv_dataset_type_inference(self): + """Tests that datasets can be created when no defaults are specified. + + In that case, we should infer the types from the first N records. + """ + # Test that it works with standard test files (with comments, header, etc) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + self._test_filenames, defaults=None, batch_size=2, num_epochs=10) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + defaults=[[], [], [], [], [""]]) + + # Test on a deliberately tricky file + fn = os.path.join(self.get_temp_dir(), "file.csv") + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, + dtypes.string, dtypes.string + ] + rows = [[0, 0, 0, "NAN", "", "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], + ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] + expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], + [123, 2, 2**64, 123.4, "", "cd,efg"]] + for row in expected: + row[-1] = row[-1].encode("utf-8") # py3 expects byte strings + row[-2] = row[-2].encode("utf-8") # py3 expects byte strings + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + with open(fn, "w") as f: + f.write(",".join(col_names)) + f.write("\n") + for row in rows: + f.write(",".join([str(v) if v else "" for v in row]) + "\n") + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + batch_size=1, + num_epochs=1, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float32, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + + # With float64 as default type for floats + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.string, dtypes.string + ] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + batch_size=1, + num_epochs=1, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float64, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + def test_make_csv_dataset_with_shuffle(self): total_records = self._num_files * self._num_records - defaults = [d for d in self.COLUMN_TYPES[:-1]] - defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + defaults = self.DEFAULTS for batch_size in [1, 2]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 3c7b46629edb13459766b5ef3f392e8d00ad4db8..5f47dcb33999119a690bd633f0c97a12a1ae1c84 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -21,7 +21,10 @@ import numpy as np from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -45,12 +48,10 @@ class ResampleTest(test.TestCase): target_dist=target_dist, initial_dist=initial_dist, class_func=lambda c, _: c, - seed=27)).make_initializable_iterator()) - init_op = iterator.initializer + seed=27)).make_one_shot_iterator()) get_next = iterator.get_next() with self.test_session() as sess: - sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -70,6 +71,43 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) + def testRandomClasses(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Apply a random mapping that preserves the data distribution. + def _remap_fn(_): + return math_ops.cast(random_ops.random_uniform([1]) * num_classes, + dtypes.int32)[0] + dataset = dataset.map(_remap_fn) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + classes, _ = zip(*returned) + bincount = np.bincount( + np.array(classes), + minlength=num_classes).astype(np.float32) / len(classes) + + self.assertAllClose(target_dist, bincount, atol=1e-2) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 36ddf3004237ed042f21d691d83eafbaa20621e6..b13ad9ba4e533e1bcef5161d983c8e6578d549b2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -47,6 +47,11 @@ class SequenceDatasetSerializationTest( # Skip nothing self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + def testInvalidSkip(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -69,6 +74,11 @@ class SequenceDatasetSerializationTest( # Take nothing self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + def testInvalidTake(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -100,6 +110,12 @@ class SequenceDatasetSerializationTest( # Test repeat empty dataset self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + def testInvalidRepeat(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), + None, 0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index c3331e963602d60fe27dd44b0cc06dfb20ca2b6a..a1a5c9ed05ff226086885e4e204875d3ca933590 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -72,14 +72,18 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", ], ) @@ -115,6 +119,7 @@ py_library( deps = [ ":contrib_op_loader", ":gen_dataset_ops", + "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -173,17 +178,9 @@ py_library( srcs = ["prefetching_ops.py"], deps = [ ":contrib_op_loader", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6eb512dec67cb7b9c8c4518d03aee0b436205f9a..1eba010b562a60ec9469f808fd657ca330a8f5d9 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse @@ -345,16 +346,61 @@ class _RestructuredDataset(dataset_ops.Dataset): return self._output_shapes +def assert_element_shape(expected_shapes): + """Assert the shape of this `Dataset`. + + ```python + shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) + print(result.output_shapes) # ==> "((16, 256), )" + ``` + + If dataset shapes and expected_shape, are fully defined, assert they match. + Otherwise, add assert op that will validate the shapes when tensors are + evaluated, and set shapes on tensors, respectively. + + Args: + expected_shapes: A nested structure of `tf.TensorShape` objects. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _check_shape(*elements): + flatten_tensors = nest.flatten(elements) + flatten_shapes = nest.flatten(expected_shapes) + checked_tensors = [with_shape(shape, tensor) + for shape, tensor in zip(flatten_shapes, + flatten_tensors)] + return nest.pack_sequence_as(elements, checked_tensors) + + def _apply_fn(dataset): + return _RestructuredDataset( + dataset.map(_check_shape), + dataset.output_types, + output_shapes=expected_shapes, + output_classes=dataset.output_classes) + + return _apply_fn + + class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( + self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches = ops.convert_to_tensor( + self._num_parallel_batches_t = ops.convert_to_tensor( num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._drop_remainder_t = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + + self._batch_size = batch_size + self._drop_remainder = drop_remainder def _as_variant_tensor(self): # pylint: disable=protected-access @@ -363,8 +409,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): input_resource, self._map_func.captured_inputs, f=self._map_func, - batch_size=self._batch_size, - num_parallel_batches=self._num_parallel_batches, + batch_size=self._batch_size_t, + num_parallel_batches=self._num_parallel_batches_t, + drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( @@ -373,9 +420,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): @property def output_shapes(self): + dim = self._batch_size if self._drop_remainder else None return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(tensor_util.constant_value( - self._batch_size)).concatenate(s) + tensor_shape.vector(dim).concatenate(s) for s in nest.flatten(self._output_shapes) ]) @@ -384,7 +431,10 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): return self._output_types -def map_and_batch(map_func, batch_size, num_parallel_batches=1): +def map_and_batch(map_func, + batch_size, + num_parallel_batches=1, + drop_remainder=False): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -404,6 +454,9 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values can increase contention if CPU is scarce. + drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the + last batch should be dropped in case its size is smaller than desired; + the default behavior is not to drop the smaller batch. Returns: A `Dataset` transformation function, which can be passed to @@ -412,6 +465,6 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches) + num_parallel_batches, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 63226fe78163c59025623a362d17c400fbe57c67..6ef65f9624601286691505a795a86dd6226eead1 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import ops def Counter(start=0, step=1, dtype=dtypes.int64): - """Creates a `Dataset` of a `step`-separated count startin from `start`. + """Creates a `Dataset` that counts from `start` in steps of size `step`. For example: @@ -38,12 +38,13 @@ def Counter(start=0, step=1, dtype=dtypes.int64): ``` Args: - start: starting value for count. - step: step size. - dtype: counter data type. + start: (Optional.) The starting value for the counter. Defaults to 0. + step: (Optional.) The step size for the counter. Defaults to 1. + dtype: (Optional.) The data type for counter elements. Defaults to + `tf.int64`. Returns: - A `Dataset` of scalar elements. + A `Dataset` of scalar `dtype` elements. """ with ops.name_scope("counter"): start = ops.convert_to_tensor(start, dtype=dtype, name="start") diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index a19be222545ef0242502ec07badbdae5c7634a0c..36591c055ae8f2c54981525ffcc3df128a990a61 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -42,7 +42,7 @@ def group_by_window(key_func, This transformation maps each consecutive element in a dataset to a key using `key_func` and groups the elements by key. It then applies `reduce_func` to at most `window_size_func(key)` elements matching the same - key. All execpt the final window for each key will contain + key. All except the final window for each key will contain `window_size_func(key)` elements; the final window may be smaller. You may provide either a constant `window_size` or a window size determined by @@ -140,9 +140,9 @@ def bucket_by_sequence_length(element_length_func, batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) - def element_to_bucket_id(element): + def element_to_bucket_id(*args): """Return int64 id of the length bucket for this element.""" - seq_length = element_length_func(element) + seq_length = element_length_func(*args) boundaries = list(bucket_boundaries) buckets_min = [np.iinfo(np.int32).min] + boundaries diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 7059b358f349e0ec847e85c37652012d48ed910a..77e23d0319e7f163f208c90bc0d5643520a4b466 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,8 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib # TODO(rohanj): Add a python class that constructs resource in the __init__ @@ -27,7 +37,6 @@ def function_buffering_resource(string_arg, target_device, f, buffer_size, - thread_pool_size=1, container="", shared_name=None, name=None): @@ -39,7 +48,6 @@ def function_buffering_resource(string_arg, shared_name=shared_name, f=f, buffer_size=buffer_size, - thread_pool_size=thread_pool_size, container=container, name=name) @@ -51,3 +59,189 @@ def function_buffering_resource_get_next(function_buffer_resource, function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) + + +def function_buffering_resource_reset(function_buffer_resource, name=None): + return gen_dataset_ops.function_buffering_resource_reset( + function_buffer_resource=function_buffer_resource, name=name) + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + device, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + + # Convert any `SparseTensorValue`s to `SparseTensor`s. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + ]) + + # Serialize any sparse tensors and convert result to tensors. + ret = nest.pack_sequence_as(ret, [ + ops.convert_to_tensor(t) + for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) + ]) + return nest.flatten(ret) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=gen_dataset_ops.iterator_get_device( + self._input_iterator._iterator_resource), + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + + if not self._one_shot: + reset_op = function_buffering_resource_reset(self._buffering_resource) + with ops.control_dependencies([reset_op]): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + self._buffering_resource, + output_types=nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + nest.flatten(ret), nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + + return ret + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + self._input_dataset = input_dataset + self._device = device + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + device=self._device, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + device=self._device, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_device()` must be the last " + "transformation in a dataset pipeline.") + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_device(device, buffer_size=None): + """A transformation that prefetches dataset values to the given `device`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + device: A string. The name of a device to which elements will be prefetched. + buffer_size: (Optional.) The number of elements to buffer on `device`. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, device, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f70f9c881df168564cbf2431bbc2ebdf7e7f7ded..9a48aa02fba4813fc670364bda7f91c0ce091a45 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,6 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import csv +from math import ceil + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops @@ -26,8 +32,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -35,21 +44,145 @@ _ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.string) +def _is_valid_int32(str_val): + try: + # Checks equality to prevent int32 overflow + return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( + str_val) + except (ValueError, OverflowError): + return False + + +def _is_valid_int64(str_val): + try: + dtypes.int64.as_numpy_dtype(str_val) + return True + except (ValueError, OverflowError): + return False + + +def _is_valid_float(str_val, float_dtype): + try: + return float_dtype.as_numpy_dtype(str_val) < np.inf + except ValueError: + return False + + +def _infer_type(str_val, na_value, prev_type, float_dtype): + """Given a string, infers its tensor type. + + Infers the type of a value by picking the least 'permissive' type possible, + while still allowing the previous type inference for this column to be valid. + + Args: + str_val: String value to infer the type of. + na_value: Additional string to recognize as a NA/NaN CSV value. + prev_type: Type previously inferred based on values of this column that + we've seen up till now. + float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type + to parse float strings as. + Returns: + Inferred dtype. + """ + if str_val in ("", na_value): + return prev_type + + if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): + return dtypes.int32 + + if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, + dtypes.int64): + return dtypes.int64 + + if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: + return float_dtype + + return dtypes.string + + +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment): + for fn in filenames: + with file_io.FileIO(fn, "r") as f: + rdr = csv.reader( + f, + delimiter=field_delim, + quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) + if header: + next(rdr) # Skip header lines + + for csv_row in rdr: + if comment is not None and csv_row[0].startswith(comment): + continue # Skip comment lines + + if len(csv_row) != num_cols: + raise ValueError( + "Problem inferring types: CSV row has different number of fields " + "than expected.") + yield csv_row + + +def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, + na_value, header, comment, float_dtype, + rows_for_inference): + """Infers column types from the first N valid CSV records of files.""" + inferred_types = [None] * num_cols + + for rows_read, csv_row in enumerate( + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment)): + if rows_for_inference is not None and rows_read >= rows_for_inference: + break + for i, str_val in enumerate(csv_row): + inferred_types[i] = _infer_type(str_val, na_value, inferred_types[i], + float_dtype) + + # Replace None's with a default type + inferred_types = [t or dtypes.string for t in inferred_types] + # Default to 0 or '' for null values + return [ + constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) + for t in inferred_types + ] + + +def _infer_column_names(filenames, field_delim, use_quote_delim): + """Infers column names from first rows of files.""" + csv_kwargs = { + "delimiter": field_delim, + "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE + } + with file_io.FileIO(filenames[0], "r") as f: + column_names = next(csv.reader(f, **csv_kwargs)) + + for name in filenames[1:]: + with file_io.FileIO(name, "r") as f: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError("Files have different column names in the header row.") + return column_names + + def make_csv_dataset( file_pattern, batch_size, - column_keys, - column_defaults, - label_key=None, + column_names=None, + column_defaults=None, + label_name=None, field_delim=",", use_quote_delim=True, - skip=0, - filter_fn=None, + na_value="", + header=True, + comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=1, + num_parallel_reads=1, + num_parallel_parser_calls=2, + sloppy=False, + default_float_type=dtypes.float32, + num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -63,27 +196,36 @@ def make_csv_dataset( records. See @{tf.gfile.Glob} for pattern rules. batch_size: An int representing the number of consecutive elements of this dataset to combine in a single batch. - column_keys: A list of strings that corresponds to the CSV columns, in - order. One per column of the input record. - column_defaults: A list of default values for the CSV fields. One item per - column of the input record. Each item in the list is either one of the - following dtypes: float32, float64, int32, int64, or string, or a - `Tensor` with one of the aforementioned types. One item per column of - the input record, with either scalar default value for that column if it - is required, or, if the column is required, an empty `Tensor` or a dtype. - label_key: A optional string corresponding to the label column. If provided, - the data for this column is returned as a separate `Tensor` from the - features dictionary, so that the dataset complies with the format expected - by a `tf.Estimator.train` or `tf.Estimator.evaluate` input function. + column_names: An optional list of strings that corresponds to the CSV + columns, in order. One per column of the input record. If this is not + provided, infers the column names from the first row of the records. + These names will be the keys of the features dict of each dataset element. + column_defaults: A optional list of default values for the CSV fields. One + item per column of the input record. Each item in the list is either a + valid CSV dtype (float32, float64, int32, int64, or string), or a + `Tensor` with one of the aforementioned types. The tensor can either be + a scalar default value (if the column is optional), or an empty tensor (if + the column is required). If a dtype is provided instead of a tensor, the + column is also treated as required. If this list is not provided, tries + to infer types based on reading the first num_rows_for_inference rows of + files specified, and assumes all columns are optional, defaulting to `0` + for numeric values and `""` for string values. + label_name: A optional string corresponding to the label column. If + provided, the data for this column is returned as a separate `Tensor` from + the features dictionary, so that the dataset complies with the format + expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input + function. field_delim: An optional `string`. Defaults to `","`. Char delimiter to separate fields in a record. use_quote_delim: An optional bool. Defaults to `True`. If false, treats double quotation marks as regular characters inside of the string fields. - skip: An integer that corresponds to the number of lines to skip at the - head of each CSV file. Defaults to 0. - filter_fn: A callable function that takes in a CSV string and returns a - boolean that corresponds to whether the record should be included. If - None, does not filter records. + na_value: Additional string to recognize as NA/NaN. + header: A bool that indicates whether the first rows of provided CSV files + correspond to header lines with column names, and should not be included + in the data. + comment: An optional character string that marks lines that should not be + parsed as csv records. If this is provided, all lines that start with + this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. @@ -94,63 +236,124 @@ def make_csv_dataset( prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of batches consumed per training step. + num_parallel_reads: Number of threads used to read CSV records from files. + If >1, the results will be interleaved. + num_parallel_parser_calls: Number of parallel invocations of the CSV parsing + function on CSV records. + sloppy: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + default_float_type: Either `tf.float32` or `tf.float64`. If defaults are + not provided, float-like strings are interpreted to be this type. + num_rows_for_inference: Number of rows of a file to use for type inference + if record_defaults is not provided. If None, reads all the rows of all + the files. Defaults to 100. Returns: A dataset, where each element is a (features, labels) tuple that corresponds to a batch of `batch_size` CSV rows. The features dictionary maps feature column names to `Tensor`s containing the corresponding column data, and labels is a `Tensor` containing the column data for the label column - specified by `label_key`. + specified by `label_name`. + + Raises: + ValueError: If any of the arguments is malformed. """ + # Create dataset of all matching filenames filenames = _get_file_names(file_pattern, False) - column_defaults = [ - constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x - for x in column_defaults - ] - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) - if label_key is not None: - assert label_key in column_keys + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Clean arguments; figure out column names and defaults + if comment is not None and len(comment) != 1: + raise ValueError("`comment` arg must be a single-character string or None") + + if column_names is None: + if not header: + raise ValueError("Cannot infer column names without a header line.") + # If column names are not provided, infer from the header lines + column_names = _infer_column_names(filenames, field_delim, use_quote_delim) + if len(column_names) != len(set(column_names)): + raise ValueError("Cannot have duplicate column names.") + + if column_defaults is not None: + column_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in column_defaults + ] + else: + # If column defaults are not provided, infer from records at graph + # construction time + column_defaults = _infer_column_defaults( + filenames, len(column_names), field_delim, use_quote_delim, na_value, + header, comment, default_float_type, num_rows_for_inference) + + if label_name is not None and label_name not in column_names: + raise ValueError("`label_name` provided must be one of the columns.") + + # Define map and filter functions + def filter_fn(line): + return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) def filename_to_dataset(filename): ds = core_readers.TextLineDataset(filename) - if skip > 0: - ds = ds.skip(skip) - if filter_fn is not None: + if header: + ds = ds.skip(1) + if comment is not None: ds = ds.filter(filter_fn) return ds def decode_csv(line): - """Decodes csv line into features. + """Decodes CSV line into features. Args: line: String tensor corresponding to one csv record. Returns: A dictionary of feature names to values for that particular record. If - label_key is provided, extracts the label feature to be returned as the + label_name is provided, extracts the label feature to be returned as the second element of the tuple. """ columns = parsing_ops.decode_csv( line, column_defaults, field_delim=field_delim, - use_quote_delim=use_quote_delim) - features = dict(zip(column_keys, columns)) - if label_key is not None: - label = features.pop(label_key) + use_quote_delim=use_quote_delim, + na_value=na_value, + ) + features = dict(zip(column_names, columns)) + if label_name is not None: + label = features.pop(label_name) return features, label return features - # TODO(rachelim): interleave records from files for better shuffling - dataset = dataset.flat_map(filename_to_dataset) - # TODO(rachelim): use fused shuffle_and_repeat for perf - if shuffle: + # Read files sequentially or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) + + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + dataset = dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) - if num_epochs != 1: + elif num_epochs != 1: dataset = dataset.repeat(num_epochs) - dataset = dataset.batch(batch_size) - dataset = dataset.map(decode_csv) + # Use map_and_batch for perf + # TODO(b/76425672): use num_parallel_calls for better performance tuning when + # that is added + dataset = dataset.apply( + batching.map_and_batch( + map_func=decode_csv, + batch_size=batch_size, + num_parallel_batches=int( + ceil(num_parallel_parser_calls / batch_size)))) + dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -246,12 +449,10 @@ def make_batched_features_dataset(file_pattern, `Tensor` or `SparseTensor` objects. """ # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) if shuffle: - dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True) - else: - # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic. - filenames = _get_file_names(file_pattern, shuffle) - dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + dataset = dataset.shuffle(len(filenames), shuffle_seed) # Read `Example` records from files as tensor objects. if reader_args is None: @@ -287,7 +488,7 @@ def make_batched_features_dataset(file_pattern, lambda x: parsing_ops.parse_example(x, features), num_parallel_calls=parser_num_threads) - # TODO(rachelim): Add an optional label_key argument for extracting the label + # TODO(rachelim): Add an optional label_name argument for extracting the label # from the features dictionary, to comply with the type expected by the # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function. dataset = dataset.prefetch(prefetch_buffer_size) diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 56f526a330bfbea7305b0754bfd114c5e97db506..b465397437adbdfaf865efb8ed2f80e57f48fcab 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -54,7 +54,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" dist_estimation_batch_size = 32 - target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") + target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") @@ -101,14 +101,16 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): initial_dist_ds)) .map(maybe_warn_on_large_rejection)) - current_probabilities_ds = dataset_ops.Dataset.zip( - (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) + def _gather_and_copy(class_val, acceptance_prob, data): + return (class_val, array_ops.gather(acceptance_prob, class_val), data) + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) filtered_ds = ( - dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, - dataset)) + current_probabilities_and_class_and_data_ds .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + return _apply_fn @@ -151,7 +153,7 @@ def _calculate_acceptance_probs(initial_probs, target_probs): ``` - A solution for a_i in terms of the other variabes is the following: + A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` """ # Add tiny to initial_probs to avoid divide by zero. diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index ae3847b8b62452b1afbe472fcb6369181ec60b73..3b50a48336d77ebd9327fa24e5612a95d5d0c372 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -13,14 +13,6 @@ load( "tf_pyclif_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD index 3dfbbf55273848afb8ad74ad444f0d85b45610bd..401527f1e74f7725d02a3b92a2c661d8ffc11e21 100644 --- a/tensorflow/contrib/deprecated/BUILD +++ b/tensorflow/contrib/deprecated/BUILD @@ -30,15 +30,3 @@ py_test( "//tensorflow/python:logging_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..74b2cd90a187159fd2da8ce236c14e813cc43c49 --- /dev/null +++ b/tensorflow/contrib/distribute/BUILD @@ -0,0 +1,36 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "distribute", + srcs = ["__init__.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/distribute/python:cross_tower_ops", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:step_fn", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28483f4c88504b1fb90f2afc927442018648fdca --- /dev/null +++ b/tensorflow/contrib/distribute/README.md @@ -0,0 +1,144 @@ +# Distribution Strategy + +> *NOTE*: This is a experimental feature. The API and performance +> characteristics are subject to change. + +## Overview + +[`DistributionStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy) +API is an easy way to distribute your training +across multiple devices/machines. Our goal is to allow users to use existing +models and training code with minimal changes to enable distributed training. +Moreover, we've design the API in such a way that it works with both eager and +graph execution. + +Currently we support one type of strategy, called +[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy). +It does in-graph replication with synchronous training +on many GPUs on one machine. Essentially, we create copies of all variables in +the model's layers on each device. We then use all-reduce to combine gradients +across the devices before applying them to the variables to keep them in sync. +In the future, we intend to support other kinds of training configurations such +as multi-node, synchronous, +[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program), +parameter servers and model parallelism. + +## Example + +Let's demonstrate how to use this API with a simple example. We will use the +[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) +approach, and show you how to scale your model to run on multiple GPUs on one +machine using `MirroredStrategy`. + +Let's consider a very simple model function which tries to learn a simple +function. + +```python +def model_fn(features, labels, mode): + layer = tf.layers.Dense(1) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + loss = tf.losses.mean_squared_error( + labels=labels, predictions=tf.reshape(logits, [])) + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss) + + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) +``` + +Let's also define a simple input function to feed data for training this model. +Note that we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. + + +```python +def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(100) + labels = tf.data.Dataset.from_tensors(1.).repeat(100) + return dataset_ops.Dataset.zip((features, labels)) +``` + +Now that we have a model function and input function defined, we can define the +estimator. To use `MirroredStrategy`, all we need to do is: + +* Create an instance of the `MirroredStrategy` class. +* Pass it to the +[`RunConfig`](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig) +parameter of `Estimator`. + + +```python +distribution = tf.contrib.distribute.MirroredStrategy() +config = tf.estimator.RunConfig(train_distribute=distribution) +classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) +classifier.train(input_fn=input_fn) +``` + +That's it! This change will now configure estimator to run on all GPUs on your +machine, with the `MirroredStrategy` approach. It will take care of distributing +the input dataset, replicating layers and variables on each device, and +combining and applying gradients. + +The model and input functions do not have to change because we have changed the +underlying components of TensorFlow (such as +optimizer, batch norm and summaries) to become distribution-aware. +That means those components know how to +combine their state across devices. Further, saving and checkpointing works +seamlessly, so you can save with one or no distribution strategy and resume with +another. + +Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__). +There are few things you can customize in practice: + +* You can specify a list of specific GPUs (using param `devices`) or the number +of GPUs (using param `num_gpus`), in case you don't want auto detection. +* You can specify various parameters for all reduce with the `cross_tower_ops` +param, such as the all reduce algorithm to use, and gradient repacking. + +## Performance Tips + +We've tried to make it such that you get the best performance for your existing +model. We also recommend you follow the tips from +[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance). +Specifically, we found using [`map_and_batch`](https://www.tensorflow.org/performance/datasets_performance#map_and_batch) +and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_performance#pipelining) +in the input function gives a solid boost in performance. When using +`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size. + +## Caveats +This feature is in early stages and there are a lot of improvements forthcoming: + +* Metrics are not yet supported during distributed training. +* Summaries are currently computed in every tower. +* Evaluation is not yet distributed. +* Eager support is in the works; performance can be more challenging with eager +execution. +* As mentioned earlier, multi-node and other distributed strategies will be +introduced in the future. +* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch) +your input data, we will place one batch on each GPU in each step. So your +effective batch size will be `num_gpus * batch_size`. Therefore, consider +adjusting your learning rate or batch size according to the number of GPUs. +We are working on addressing this limitation by splitting each batch across GPUs +instead. +* Dictionaries inside dataset in the input are not supported when prefetching +on GPUs is turned on. (If you need to use dictionaries in the dataset, turn off +prefetching on GPUs by passing param `prefetch_on_device=False` to +`MirroredStrategy`) +* PartitionedVariables are not supported yet. + +## What's next? + +Please give distribution strategies a try. This feature is in early stages and +is evolving, so we welcome your feedback via +[issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new). + + diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76711baf3a11c8978fbb5770ec173ff74a153158 --- /dev/null +++ b/tensorflow/contrib/distribute/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Prototype of a distributed computation library for TF.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.cross_tower_ops import * +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.python.training.distribute import * + +from tensorflow.python.util.all_util import remove_undocumented + + +_allowed_symbols = [ + 'AllReduceCrossTowerOps', + 'CrossTowerOps', + 'DistributionStrategy', + 'MirroredStrategy', + 'Monitor', + 'OneDeviceStrategy', + 'ReductionToOneDeviceCrossTowerOps', + 'Step', + 'StandardInputStep', + 'StandardSingleLossStep', + 'TowerContext', + 'get_cross_tower_context', + 'get_distribution_strategy', + 'get_loss_reduction', + 'get_tower_context', + 'has_distribution_strategy', + 'require_tower_context', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..78b2b0054aa95701ad192b4fb9a0727ce287de4b --- /dev/null +++ b/tensorflow/contrib/distribute/python/BUILD @@ -0,0 +1,444 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# TODO(priyag): Figure out testonly issues that are preventing us from +# including our tests in pip for now. + +py_library( + name = "values", + srcs = ["values.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":prefetching_ops_v2", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:checkpointable", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "values_test", + srcs = ["values_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:errors", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "mirrored_strategy", + srcs = ["mirrored_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":shared_variable_creator", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device", + "//tensorflow/python:framework_ops", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "@six_archive//:six", + ], +) + +py_library( + name = "one_device_strategy", + srcs = ["one_device_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":values", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "strategy_test_lib", + testonly = 1, + srcs = ["strategy_test_lib.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "combinations", + testonly = 1, + srcs = ["combinations.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":one_device_strategy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "combinations_test", + srcs = ["combinations_test.py"], + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "mirrored_strategy_test", + srcs = ["mirrored_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":one_device_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:test", + ], +) + +cuda_py_test( + name = "mirrored_strategy_multigpu_test", + srcs = ["mirrored_strategy_multigpu_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "guitar", + "no_pip", + "multi_and_single_gpu", + # Do not perform the extra analysis on this test, because it is already + # performed for the `:mirrored_strategy_test` target. + "no_oss", + "noasan", + "notap", + "notsan", + ], +) + +py_library( + name = "step_fn", + srcs = ["step_fn.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:training", + "//tensorflow/python/eager:backprop", + ], +) + +cuda_py_test( + name = "minimize_loss_test", + srcs = ["minimize_loss_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/ops/losses", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "optimizer_v2_test", + srcs = ["optimizer_v2_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "estimator_integration_test", + srcs = ["estimator_integration_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:dnn_linear_combined", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "single_loss_example", + srcs = ["single_loss_example.py"], + deps = [ + ":step_fn", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":single_loss_example", + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "monitor", + srcs = ["monitor.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "monitor_test", + srcs = ["monitor_test.py"], + additional_deps = [ + ":combinations", + ":monitor", + ":one_device_strategy", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "shared_variable_creator", + srcs = ["shared_variable_creator.py"], + visibility = ["//tensorflow:internal"], +) + +py_test( + name = "shared_variable_creator_test", + srcs = ["shared_variable_creator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":shared_variable_creator", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "cross_tower_utils", + srcs = ["cross_tower_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/nccl:nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "cross_tower_ops", + srcs = ["cross_tower_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":cross_tower_utils", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_test( + name = "cross_tower_ops_test", + srcs = ["cross_tower_ops_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":cross_tower_ops", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py new file mode 100644 index 0000000000000000000000000000000000000000..02b1e7ef9fcd4767c59898bd343e712e285e67d5 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -0,0 +1,297 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Facilities for creating multiple test combinations. + +Here is an example of testing various optimizers in Eager and Graph mode: + +class AdditionExample(test.TestCase, parameterized.TestCase): + @combinations.generate( + combinations.combine(mode=["graph", "eager"], + optimizer=[AdamOptimizer(), + GradientDescentOptimizer()])) + def testOptimizer(self, optimizer): + ... f(optimizer)... + +This will run `testOptimizer` 4 times with the specified optimizers: 2 in +Eager and 2 in Graph mode. +The test will be provided with arguments that match the arguments of combine +by name. It is necessary to request all arguments, except for `mode`, which is +optional. + +`combine()` function is available for creating a cross product of various +options. `times()` function exists for creating a product of N `combine()`-ed +results. See below. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import sys +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.optimizer_v2 import adam as adam_v2 +from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.util import tf_inspect + + +GPU_TEST = "test_gpu" in sys.argv[0] + + +def generate(combinations): + """A decorator for generating test cases of a test method or a test class. + + Args: + combinations: a list of dictionaries created using combine() and times(). + + Restrictions: + -- there should always be a "mode" argument. Accepted values are "eager" + and "graph". + -- arguments of the test method must match by name to get the corresponding + value of the combination. Tests must accept all arguments (except "mode", + which is optional). + -- distribution argument is special. It is meant for passing instances of + DistributionStrategy. Each instance is to be passed as `(, + )` tuple, where is the number of required + GPUs. If the required number of GPUs for the DistributionStrategy isn't + available then the test case is going to be skipped. + + Returns: + a decorator that will cause the test method to be run under the specified + conditions. + + Raises: + ValueError - if "mode" argument wasn't either "eager" or "graph. + """ + + def decorator(test_function): + """The decorator to be returned.""" + + # Generate good test names that can be used with --test_filter. + for combination in combinations: + # We use OrderedDicts in `combine()` and `times()` to ensure stable + # order of keys in each dictionary. + assert isinstance(combination, OrderedDict) + name = "".join([ + "_{}_{}".format( + "".join(filter(str.isalnum, key)), + "".join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + combination.update({"testcase_name": "_test{}".format(name)}) + + @parameterized.named_parameters(*combinations) + def decorated(self, **kwargs): + """A wrapped test method that sets up `test_function`.""" + assert "mode" in kwargs + mode = kwargs["mode"] + + if "distribution" in kwargs: + distribution = kwargs["distribution"] + kwargs["distribution"] = distribution.strategy + if not distribution.required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < distribution.required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(distribution.required_gpus, context.num_gpus())) + + requested_arguments = tf_inspect.getfullargspec(test_function).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with context.eager_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + elif mode == "graph": + with context.graph_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + + return decorated + return decorator + + +def combine(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = combine(**rest) + + key = first[0] + values = first[1] + + return [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + + +def times(*combined): + """Generate a product of N sets of combinations. + + times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4]) + + Args: + *combined: N lists of dictionaries that specify combinations. + + Returns: + a list of dictionaries for each combination. + + Raises: + ValueError: if some of the inputs have overlapping keys. + """ + assert combined + + if len(combined) == 1: + return combined[0] + + first = combined[0] + rest_combined = times(*combined[1:]) + + combined_results = [] + for a in first: + for b in rest_combined: + if set(a.keys()).intersection(set(b.keys())): + raise ValueError("Keys need to not overlap: {} vs {}".format( + a.keys(), b.keys())) + + combined_results.append(OrderedDict(list(a.items()) + list(b.items()))) + return combined_results + + +class NamedObject(object): + """A class that translates an object into a good test name.""" + + def __init__(self, name, obj): + self._name = name + self._obj = obj + + def __getattr__(self, name): + return getattr(self._obj, name) + + def __call__(self, *args, **kwargs): + return self._obj(*args, **kwargs) + + def __repr__(self): + return self._name + + +class NamedDistribution(object): + """Translates DistributionStrategy and its data into a good name.""" + + def __init__(self, name, distribution, required_gpus): + self._distribution = distribution + self._name = name + self._required_gpus = required_gpus + + def __repr__(self): + return self._name + + @property + def strategy(self): + return self._distribution + + @property + def required_gpus(self): + return self._required_gpus + + +one_device_strategy = NamedDistribution( + "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + None) +mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "MirroredCPUAndGPU", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) +mirrored_strategy_without_prefetch = NamedDistribution( + "MirroredCPUAndGPUNoPrefetch", + mirrored_strategy.MirroredStrategy( + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) +mirrored_strategy_with_two_gpus = NamedDistribution( + "Mirrored2GPUs", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + +adam_optimizer_v1_fn = NamedObject( + "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v1_fn = NamedObject( + "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) + +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v2_fn = NamedObject( + "GradientDescentV2", + lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) + +graph_and_eager_modes = ["graph", "eager"] + + +def distributions_and_v1_optimizers(): + """A common set of combination with DistributionStrategies and Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + + +def distributions_and_v2_optimizers(): + """DistributionStrategies and V2 Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py new file mode 100644 index 0000000000000000000000000000000000000000..219b24160f3902fcfa5363cc39a8fc5b30d00308 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for some testing utils from strategy_test_lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test + + +class TestingCombinationsTest(test.TestCase): + + def test_combine(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 1, + "b": 3 + }, { + "a": 2, + "b": 2 + }, { + "a": 2, + "b": 3 + }], combinations.combine(a=[1, 2], b=[2, 3])) + + def test_add(self): + self.assertEqual( + [{ + "a": 1 + }, { + "a": 2 + }, { + "b": 2 + }, { + "b": 3 + }], + combinations.combine(a=[1, 2]) + + combinations.combine(b=[2, 3])) + + def test_times(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1 + c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "eager")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "eager")]) + ], c4) + + def test_times_variable_arguments(self): + c1 = combinations.combine(mode=["graph", "eager"]) + c2 = combinations.combine(optimizer=["adam", "gd"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1, c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "gd")]) + ], c4) + self.assertEqual( + combinations.combine( + mode=["graph", "eager"], + optimizer=["adam", "gd"], + distribution=["d1", "d2"]), c4) + + def test_overlapping_keys(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"): + _ = combinations.times(c1, c2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe5e877d59518056db3fea251cdae0ed854d0e4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -0,0 +1,585 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes for different algorithms of reduction and broadcasting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.client import device_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import device_util + + +def _validate_destinations(destinations): + if not isinstance(destinations, + (value_lib.DistributedValues, six.string_types, list)): + raise ValueError("destinations must be one of a `DistributedValues` object," + " a device string, a list of device strings or None") + + if not destinations: + raise ValueError("destinations can not be empty") + + +def _validate_value_destination_pairs(value_destination_pairs): + # pylint: disable=g-missing-docstring + if not value_destination_pairs: return False + if not isinstance(value_destination_pairs, (list, tuple)): return False + if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): + return False + if not all([isinstance(v[0], value_lib.PerDevice) + for v in value_destination_pairs]): + return False + return True + + +def _get_devices_from(destinations): + if isinstance(destinations, value_lib.DistributedValues): + return list(destinations.devices) + elif isinstance(destinations, six.string_types): + return [device_util.canonicalize(destinations)] + else: + return [ + device_util.canonicalize(destination) for destination in destinations + ] + + +def _devices_match(left, right): + return set(_get_devices_from(left)) == set(_get_devices_from(right)) + + +def _all_devices_match(value_destination_pairs): + if not all([d is None or _devices_match(v, d) + for v, d in value_destination_pairs]): + return False + if not all([_devices_match(v, value_destination_pairs[0][0]) + for v, _ in value_destination_pairs[1:]]): + return False + return True + + +def _simple_broadcast(tensor, destinations): + index = {} + devices = _get_devices_from(destinations) + for d in devices: + with ops.device(d): + index[d] = array_ops.identity(tensor) + return value_lib.Mirrored(index) + + +def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, + method_string): + # pylint: disable=g-missing-docstring + all_values = [] + count = 0 + for v in per_device_value._index.values(): # pylint: disable=protected-access + if isinstance(v, value_lib.MapOutput): + v_list = v.get() + if not v_list: + continue + count += len(v_list) + # Sum within each device before aggregating across devices. + v = math_ops.add_n(v_list) + else: + count += 1 + all_values.append(v) + if not all_values: + raise ValueError("`per_device_value` must be non-empty") + + with ops.device(reduce_to_device): + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + if method_string == "sum": + reduced = accumulation_fn(all_values) + elif method_string == "mean": + reduced = accumulation_fn(all_values) / count + else: + raise ValueError("`method_string` must be 'sum' or 'mean'") + return reduced + + +class CrossTowerOps(object): + """Base class for cross-tower reduction and broadcasting algorithms.""" + + def __init__(self): + pass + + def reduce(self, method_string, per_device_value, destinations=None): + """Reduce `per_device_value` to `destinations`. + + It runs the reduction operation defined by `method_string` and put the + result on `destinations`. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + per_device_value: a PerDevice object. + destinations: the reduction destinations. + + Returns: + a Mirrored object. + + Raises: + ValueError: if per_device_value is not a PerDevice object. + """ + if not isinstance(per_device_value, value_lib.PerDevice): + raise ValueError("`per_device_value` must be a `PerDevice` object.") + if destinations is not None: + _validate_destinations(destinations) + return self._reduce(method_string, per_device_value, destinations) + + def batch_reduce(self, method_string, value_destination_pairs): + """Reduce PerDevice objects in a batch. + + Reduce each first element in `value_destination_pairs` to each second + element which indicates the destinations. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + value_destination_pairs: a list or a tuple of tuples of PerDevice objects + and destinations. If a destination is None, then the destinations + are set to match the devices of the input PerDevice object. + + Returns: + a list of Mirrored objects. + + Raises: + ValueError: if `value_destination_pairs` is not a list or a tuple of + tuples of PerDevice objects and destinations + """ + if not _validate_value_destination_pairs(value_destination_pairs): + raise ValueError("`value_destination_pairs` must be a list or a tuple of " + "tuples of PerDevice objects and destinations") + for _, d in value_destination_pairs: + if d is not None: + _validate_destinations(d) + + return self._batch_reduce(method_string, value_destination_pairs) + + def broadcast(self, tensor, destinations): + """Broadcast the `tensor` to destinations. + + Args: + tensor: the tensor to broadcast. + destinations: the broadcast destinations. + + Returns: + a Mirrored object. + """ + _validate_destinations(destinations) + return self._broadcast(tensor, destinations) + + def _reduce(self, method_string, per_device_value, destinations): + raise NotImplementedError( + "_reduce method must be implemented in descendants.") + + def _batch_reduce(self, method_string, value_destination_pairs): + raise NotImplementedError( + "_batch_reduce method must be implemented in descendants.") + + def _broadcast(self, tensor, destinations): + return _simple_broadcast(tensor, destinations) + + +class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): + """Always do reduction to one device first and then do broadcasting. + + Batch reduction is done by reduction on each element one by one. + """ + + def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n): + """Constructor. + + Args: + reduce_to_device: the intermediate device to reduce to. If None, reduce + to the first device in `destinations` of the reduce() method. + accumulation_fn: a function that does accumulation. + """ + self.reduce_to_device = reduce_to_device + self.accumulation_fn = accumulation_fn + super(ReductionToOneDeviceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = self.reduce_to_device or devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + self.accumulation_fn, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + return [self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs] + + +def _group_value_by_device(per_device_values): + """Group values into sublists by their devices. + + This grouping is needed to call the all-reduce library. + + Args: + per_device_values: a list of PerDevice obejcts. + + Returns: + a list of lists, each sublist has components for its corresponding device of + PerDevice objects, paired with a None. + """ + destinations = per_device_values[0].devices + grouped = [[] for _ in range(len(destinations))] + for per_device_value in per_device_values: + # pylint: disable=protected-access + for i, v in enumerate(per_device_value._index.values()): + assert per_device_value.devices == destinations + grouped[i].append((v, None)) + return grouped + + +def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): + """Ungroup results from all-reduce and make Mirrored objects. + + Each all-reduce result will be divided by the number of destinations before + Mirrored objects are created if method_string is "mean". + + Args: + grouped_reduced: a list of lists, each sublist has components for each + device, paired with a None. It is the result from + cross_tower_utils.aggregate_gradients_using*. + destinations: a list of device strings for returned Mirrored objects. + method_string: "mean" or "sum". + + Returns: + a list of Mirrored objects. + """ + index = [{} for _ in range(len(grouped_reduced[0]))] + for d, per_device_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_device_reduced): + if method_string == "mean": + index[i][destinations[d]] = v / len(destinations) + else: + index[i][destinations[d]] = v + return [value_lib.Mirrored(v) for v in index] + + +class ConcatAndSplitPacker(object): + """Concatenate and split tensors for reduction.""" + + def __init__(self, num_packs=1): + """Initialize the ConcatAndSplitPacker object. + + Args: + num_packs: specifies the number of split packs that will be + formed. + + Raises: + ValueError: if num_packs is not greater than 0. + """ + if num_packs <= 0: + raise ValueError("num_packs must be greater than zero.") + self.num_packs = num_packs + + def pack(self, grouped_grads_and_vars): + """Pack tensors.""" + self.grouped_grads_and_vars = grouped_grads_and_vars + self.all_tower_shapes = [] + self.all_tower_sizes = [] + + device_grad_packs = [] + for tower_grads_and_vars in grouped_grads_and_vars: + with ops.colocate_with(tower_grads_and_vars[0][0]): + # Flatten all the grads. + flat_grads = [ + array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars + ] + # Remember the original shape of all the grads. + tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars] + # Remember the original sizes of all the grads. + tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars] + # Concat all the flat grads into a big flat tensor. + concat_grads = array_ops.concat(flat_grads, 0) + + # Split the big tensor into num_splits packs. In cases where the + # total size is not divisible num_splits, the last pack gets + # more elements. + # TODO(zhengxq): it is also possible to optimize away all the concat + # as well. + num_splits = self.num_packs + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits + split_size_last = total_grad_size - split_size * (num_splits - 1) + split_sizes = [split_size] * (num_splits - 1) + [split_size_last] + grad_packs = array_ops.split(concat_grads, split_sizes) + + # Ready to aggregate the repacked gradients, with fake variables. + # TODO(zhengxq): It is hacky to have to use fake variables. + # We should remove the need for variables in + # aggregate_gradients_using*. + device_grad_packs.append(zip(grad_packs, [None] * num_splits)) + self.all_tower_shapes.append(tower_shapes) + self.all_tower_sizes.append(tower_sizes) + + return device_grad_packs + + def unpack(self, summed_device_grad_packs): + """Reverse the pack.""" + aggregated_device_grads = [] + for (summed_tower_grad_packs, + tower_grads_and_vars, tower_shapes, tower_sizes) in zip( + summed_device_grad_packs, self.grouped_grads_and_vars, + self.all_tower_shapes, self.all_tower_sizes): + # pylint: enable=line-too-long + # Reverse the packing operations in the previous steps. Form the + # summed gradients back into their original shapes. + with ops.colocate_with(summed_tower_grad_packs[0][0]): + # Form a list of the summed grad packs. + device_grad_packs = [g for g, _ in summed_tower_grad_packs] + + # Concat them back into a big flat tensor. + device_grads_concat = array_ops.concat(device_grad_packs, 0) + + # Split the tensors back into their original sizes. + grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes) + + # Reshape the tensors back into their original shapes. + grads_with_shapes = [ + array_ops.reshape(grad, shape) + for shape, grad in zip(tower_shapes, grads_with_sizes) + ] + + # Form the list with the original list of variables. + summed_tower_grads = [ + (g, v) for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars) + ] + aggregated_device_grads.append(summed_tower_grads) + return aggregated_device_grads + + +class AggregateSmallTensorPacker(object): + """Concatenate small gradient tensors together for reduction.""" + + def __init__(self, + agg_small_grads_max_bytes=1048576, + agg_small_grads_max_group=16): + """Initialize the AggregateSmallTensorPacker object. + + Args: + agg_small_grads_max_bytes: largest tensor eligible for aggregation, + in number of bytes. + agg_small_grads_max_group: largest permitted aggregation of small + tensors. + + Raises: + ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group` + is not greater than 0. + """ + if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0: + raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group" + " should both be greater than zero.") + self.agg_small_grads_max_bytes = agg_small_grads_max_bytes + self.agg_small_grads_max_group = agg_small_grads_max_group + + def pack(self, grouped_grads_and_vars): + """Aggregate small tensors.""" + if (self.agg_small_grads_max_bytes > 0 and + self.agg_small_grads_max_group > 0): + tower_grads, self.packing = cross_tower_utils.pack_small_tensors( + grouped_grads_and_vars, + max_bytes=self.agg_small_grads_max_bytes, + max_group=self.agg_small_grads_max_group) + return tower_grads + + def unpack(self, summed_device_grad_packs): + """Reverse the aggregation process.""" + return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs, + self.packing) + + +class AllReduceCrossTowerOps(CrossTowerOps): + """Reduction using all reduce.""" + + def __init__(self, + all_reduce_alg="nccl", + num_packs=1, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """All-reduce implementation of CrossTowerOps. + + Before performing all-reduce, tensors will be repacked or aggregated for + more efficient cross-device transportation: + 1) If `num_packs` is non-zero, pack values into + `num_packs` splits. + 2) Otherwise, if `agg_small_grads_max_bytes` > 0 and + `agg_small_grads_max_group` > 0, aggregate values smaller than + `agg_small_grads_max_bytes` into groups with at most + `agg_small_grads_max_group` values. + 3) Otherwise, no repacking or grouping will happen. + + Args: + all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or + "hierarchical_copy" are supported. + num_packs: see above. + agg_small_grads_max_bytes: see above. + agg_small_grads_max_group: see above. + tensors. + """ + self.all_reduce_alg = all_reduce_alg + self.num_packs = num_packs + self.agg_small_grads_max_bytes = agg_small_grads_max_bytes + self.agg_small_grads_max_group = agg_small_grads_max_group + super(AllReduceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + if ((destinations is None or _devices_match(per_device_value, destinations)) + and not context.executing_eagerly()): + return self._batch_all_reduce(method_string, [per_device_value])[0] + else: + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + math_ops.add_n, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + if (_all_devices_match(value_destination_pairs) and + not context.executing_eagerly()): + return self._batch_all_reduce(method_string, + [v[0] for v in value_destination_pairs]) + else: + if not context.executing_eagerly(): + logging.warning("Efficient batch_reduce is not supported if " + "destinations are different.") + return [ + self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + destinations = per_device_values[0].devices + grouped = _group_value_by_device(per_device_values) + if self.num_packs > 0: + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s and num_packs = %d", len(per_device_values), + self.all_reduce_alg, self.num_packs) + tensor_packer = ConcatAndSplitPacker(self.num_packs) + device_grad_packs = tensor_packer.pack(grouped) + elif (self.agg_small_grads_max_bytes > 0 and + self.agg_small_grads_max_group > 0): + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self.all_reduce_alg, self.agg_small_grads_max_bytes, + self.agg_small_grads_max_group) + tensor_packer = AggregateSmallTensorPacker(100, 10) + device_grad_packs = tensor_packer.pack(grouped) + else: + logging.info( + "batch_all_reduce invoked for batches size = %d with algorithm = %s", + len(per_device_values), self.all_reduce_alg) + tensor_packer = None + device_grad_packs = grouped + + # The actual aggregation of the repacked gradients. Note that they are + # sharded among different aggregation trees. So it is important to strike + # the balance on num_splits. + if self.all_reduce_alg == "nccl": + reduced = cross_tower_utils.aggregate_gradients_using_nccl( + device_grad_packs) + else: + # TODO(yuefengz): check that gpu ids in `destinations` are in ascending + # order. + reduced = ( + cross_tower_utils.aggregate_gradients_using_hierarchical_copy( + destinations, device_grad_packs)) + + if tensor_packer: + reduced = tensor_packer.unpack(reduced) + + return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + method_string) + + +_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + + +def _has_dgx1_like_links(gpu_links): + if not gpu_links: + return False + # TODO(yuefengz): figure out the right topology for hierarchial copy if + # number of gpus are less than 8. + if len(gpu_links) < 8: + return False + for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)): + if (set(gpu_link) != set(dgx1_link) and + set(gpu_link) != set(dgx1_link + [i])): + return False + return True + + +def _choose_all_reduce_algorithm(device_links): + if _has_dgx1_like_links(device_links): + logging.info("Configured hierarchical_copy with num_packs=%d", + len(device_links)) + return AllReduceCrossTowerOps( + "hierarchical_copy", num_packs=len(device_links)) + else: + logging.info("Configured nccl all-reduce.") + return AllReduceCrossTowerOps("nccl", num_packs=1) + + +def choose_the_best(devices, session_config=None): + """Find the best subclass of CrossTowerOps given a tensorflow session. + + Args: + devices: a list of devices passed for distribute strategy. + session_config: a tensorflow session config or None. If None, it will make + deciesion based on all local devices. + + Returns: + a subclass of CrossTowerOps. + """ + requested_devices = set([device_util.canonicalize(d) for d in devices]) + machine_devices = device_lib.list_local_devices(session_config=session_config) + using_devices = [] + for d in machine_devices: + if device_util.canonicalize(d.name) in requested_devices: + using_devices.append(d) + else: + logging.info( + "Device is available but not used by distribute strategy: %s", d.name) + + if len(using_devices) != len(requested_devices): + logging.warning("Not all devices in distribute strategy are visible by " + "TensorFlow sessions.") + return ReductionToOneDeviceCrossTowerOps() + + if any([d.device_type.lower() != "gpu" for d in using_devices]): + logging.warning("Not all devices in DistributionStrategy are visible to " + "TensorFlow session.") + return ReductionToOneDeviceCrossTowerOps() + + device_links = [[] for _ in range(len(using_devices))] + for i, device in enumerate(using_devices): + for link in device.locality.links.link: + device_links[i].append(link.device_id) + + return _choose_all_reduce_algorithm(device_links) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7b0870887465ec2fe40007695d099277db38bf --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -0,0 +1,221 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CrossTowerOps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _make_per_device(values, devices): + devices = cross_tower_ops_lib._get_devices_from(devices) + assert len(values) == len(devices) + index = {} + for d, v in zip(devices, values): + with ops.device(d): + placed_v = array_ops.identity(v) + index[d] = placed_v + return value_lib.PerDevice(index) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _fake_mirrored(value, devices): + """Create a faked Mirrored object for testing. + + All components of the returned Mirrored have the same objects, which is not + true in reality. + """ + devices = cross_tower_ops_lib._get_devices_from(devices) + return value_lib.Mirrored( + {d: v for d, v in zip(devices, [value] * len(devices))}) + + +_cpu_device = "/device:CPU:0" + + +class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): + + def _assert_value_equal(self, left, right): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_value_equal(l, r) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(left.devices, right.devices) + if context.executing_eagerly(): + self.assertEqual([v.numpy() for v in left._index.values()], + list(right._index.values())) + else: + with self.test_session() as sess: + self.assertEqual( + sess.run(list(left._index.values())), list(right._index.values())) + + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + devices = distribution.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_value_equal( + cross_tower_ops.reduce("mean", per_device, destinations=destinations), + _fake_mirrored(mean, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "mean", per_device_2, destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce("sum", per_device, destinations=destinations), + _fake_mirrored(mean * len(devices), destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "sum", per_device_2, destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations or per_device)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "mean", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2)]) + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "sum", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + + # test broadcast() + for destinations in all_destinations: + if destinations is None: + continue + else: + self._assert_value_equal( + cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + def testChooseAlgorithm(self): + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "hierarchical_copy") + self.assertEqual(result.num_packs, 8) + + # if there are only 4 devices + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "nccl") + self.assertEqual(result.num_packs, 1) + + # if devices links contain each device itself + device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], + [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], + [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "hierarchical_copy") + self.assertEqual(result.num_packs, 8) + + # if not dgx1-like links + device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], + [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "nccl") + self.assertEqual(result.num_packs, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc04e2195f6d305e0f7c642f24c355286f1a8cfa --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -0,0 +1,339 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for cross_tower_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as pycoll + +from tensorflow.contrib import nccl +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def aggregate_gradients_using_nccl(tower_grads): + """Aggregate gradients using nccl allreduce.""" + agg_all_g_and_v = [] + for single_g_and_v in zip(*tower_grads): + single_grads = [g for g, _ in single_g_and_v] + agg_grads = nccl.all_sum(single_grads) + agg_all_g_and_v.append( + [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) + + agg_all_g_and_v = list(zip(*agg_all_g_and_v)) + + return agg_all_g_and_v + + +def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads): + """Aggregate gradients using hierarchical copies. + + Args: + avail_devices: available GPU devices. + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + + Returns: + The list of (aggregated_gradient, variable), where the gradient has been + summed across all towers and the variable is chosen from the first tower. + """ + # This only works for DGX-1 type of machine topology + # Device peer to peer matrix + # DMA: 0 1 2 3 4 5 6 7 + # 0: Y Y Y Y Y N N N + # 1: Y Y Y Y N Y N N + # 2: Y Y Y Y N N Y N + # 3: Y Y Y Y N N N Y + # 4: Y N N N Y Y Y Y + # 5: N Y N N Y Y Y Y + # 6: N N Y N Y Y Y Y + # 7: N N N Y Y Y Y Y + agg_grads = [] + num_devices = len(avail_devices) + # In the special case of DGX-1 machine topology, the two groups have equal + # size. + group_size = num_devices // 2 + for i, single_grads in enumerate(zip(*tower_grads)): + group_0_main_device = i % num_devices + group_1_main_device = (group_0_main_device + group_size) % num_devices + if group_0_main_device < group_size: + group_0_begin = 0 + group_1_begin = group_size + else: + group_0_begin = group_size + group_1_begin = 0 + + # Aggregate the first group. + group_0_device_grads = single_grads[group_0_begin: + group_0_begin + group_size] + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads, _ = aggregate_single_gradient_using_copy( + group_0_device_grads, False, False) + + # Aggregate the second group. + group_1_device_grads = single_grads[group_1_begin: + group_1_begin + group_size] + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads, _ = aggregate_single_gradient_using_copy( + group_1_device_grads, False, False) + + # Aggregate between the groups. + with ops.device(avail_devices[group_0_main_device]): + (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( + [group_0_agg_grads, group_1_agg_grads], False, False) + + # Broadcast the result back into the root of each group. + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) + + agg_grads_bcast = [] + for j in range(len(single_grads)): + with ops.device(avail_devices[j]): + # Broadcast the result back to each member in the group from the root. + if (group_0_main_device < group_size) == (j < group_size): + src_device_grad = group_0_agg_grads_bcast + else: + src_device_grad = group_1_agg_grads_bcast + agg_grads_bcast.append(array_ops.identity(src_device_grad)) + + agg_grads.append( + [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) + + agg_grads = list(zip(*agg_grads)) + + return agg_grads + + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = math_ops.add_n(grads) + + if use_mean and len(grads) > 1: + grad = array_ops.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = array_ops.logical_not( + array_ops.reduce_all(array_ops.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None + + +def extract_ranges(index_list, range_size_limit=32): + """Extract consecutive ranges and singles from index_list. + + Args: + index_list: List of monotone increasing non-negative integers. + range_size_limit: Largest size range to return. If a larger + consecutive range exists, it will be returned as multiple + ranges. + + Returns: + (ranges, singles) where ranges is a list of [first, last] pairs of + consecutive elements in index_list, and singles is all of the + other elements, in original order. + """ + if not index_list: + return [], [] + first = index_list[0] + last = first + ranges = [] + singles = [] + for i in index_list[1:]: + if i == last + 1 and (last - first) <= range_size_limit: + last = i + else: + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + first = i + last = i + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + return ranges, singles + + +GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') + + +def pack_range(key, packing, grad_vars, rng): + """Form the concatenation of a specified range of gradient tensors. + + Args: + key: Value under which to store meta-data in packing that will be used + later to restore the grad_var list structure. + packing: Dict holding data describing packed ranges of small tensors. + grad_vars: List of (grad, var) pairs for one tower. + rng: A pair of integers giving the first, last indices of a consecutive + range of tensors to be packed. + + Returns: + A tensor that is the concatenation of all the specified small tensors. + """ + to_pack = grad_vars[rng[0]:rng[1] + 1] + members = [] + variables = [] + restore_shapes = [] + with ops.name_scope('pack'): + for g, v in to_pack: + variables.append(v) + restore_shapes.append(g.shape) + with ops.device(g.device): + members.append(array_ops.reshape(g, [-1])) + packing[key] = GradPackTuple( + indices=range(rng[0], rng[1] + 1), + vars=variables, + shapes=restore_shapes) + with ops.device(members[0].device): + return array_ops.concat(members, 0) + + +def unpack_grad_tuple(gv, gpt): + """Unpack a previously packed collection of gradient tensors. + + Args: + gv: A (grad, var) pair to be unpacked. + gpt: A GradPackTuple describing the packing operation that produced gv. + + Returns: + A list of (grad, var) pairs corresponding to the values that were + originally packed into gv, maybe following subsequent operations like + reduction. + """ + elt_widths = [x.num_elements() for x in gpt.shapes] + with ops.device(gv[0][0].device): + with ops.name_scope('unpack'): + splits = array_ops.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]), + gpt.vars[idx])) + return unpacked_gv + + +def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): + """Concatenate small gradient tensors together for reduction. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. + max_bytes: Int giving max number of bytes in a tensor that + may be considered small. + max_group: Int giving max number of small tensors that may be + concatenated into one new tensor. + + Returns: + new_tower_grads, packing where new_tower_grads is identical to + tower_grads except that all feasible small_tensors have been removed + from their places and concatenated into larger tensors that are + now in the front of the list for each tower, and packing contains + the data necessary to restore the tower_grads structure. + + Look through the first tower for gradients of the same type (float), + and small size, that are all sequential. For each such group, + replace by a new tensor that is a flattened concatenation. Note + that the corresponding variable will be absent, which doesn't matter + because it isn't used during all-reduce. + + Requires: + Every gv_list in towers must have isomorphic structure including identical + tensor sizes and types. + """ + small_indices = [] + large_indices = [] + for idx, (g, _) in enumerate(tower_grads[0]): + if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes: + small_indices.append(idx) + else: + large_indices.append(idx) + small_ranges, small_singles = extract_ranges( + small_indices, range_size_limit=max_group) + large_indices = sorted(large_indices + small_singles) + num_gv = len(tower_grads[0]) + packing = {} + if small_ranges: + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing + else: + return tower_grads, None + + +def unpack_small_tensors(tower_grads, packing): + """Undo the structure alterations to tower_grads done by pack_small_tensors. + + Args: + tower_grads: List of List of (grad, var) tuples. + packing: A dict generated by pack_small_tensors describing the changes + it made to tower_grads. + + Returns: + new_tower_grads: identical to tower_grads except that concatenations + of small tensors have been split apart and returned to their original + positions, paired with their original variables. + """ + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + gv_list = list(gv_list) + new_gv_list = gv_list[num_packed:] + for i in xrange(0, num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2b49b8f4ef2937c7cffbdbd36ca50f6b0db8c1b0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -0,0 +1,127 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show that DistributionStrategy works with canned Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.optimizer_v2 import adagrad +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.summary.writer import writer_cache + + +class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, + parameterized.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def dataset_input_fn(self, x, y, batch_size, shuffle): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + if shuffle: + dataset = dataset.shuffle(batch_size) + dataset = dataset.repeat(10).batch(batch_size) + return dataset + + return input_fn + + @combinations.generate( + combinations.combine( + mode=['graph'], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_without_prefetch + ])) + def test_complete_flow_with_mode(self, distribution): + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + train_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices), + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, y=data, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, batch_size=batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + estimator = dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=label_dimension, + model_dir=self._model_dir, + # TODO(isaprykin): Work around the colocate_with error. + dnn_optimizer=adagrad.AdagradOptimizer(0.001), + linear_optimizer=adagrad.AdagradOptimizer(0.001), + config=run_config.RunConfig(train_distribute=distribution)) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cbfd17850212a1c007e2edb9dd3986b3109f040d --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -0,0 +1,30 @@ +# Example TensorFlow models that use DistributionStrategy for training. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "simple_estimator_example", + srcs = ["simple_estimator_example.py"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "simple_tfkeras_example", + srcs = [ + "simple_tfkeras_example.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py new file mode 100644 index 0000000000000000000000000000000000000000..00c25c7a2482a559c8b94ff3be86c4961dfb439f --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple example to test the a DistributionStrategy with Estimators. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def build_model_fn_optimizer(): + """Simple model_fn with optimizer.""" + # TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is + # done? + optimizer = tf.train.GradientDescentOptimizer(0.2) + + def model_fn(features, labels, mode): # pylint: disable=unused-argument + """model_fn which uses a single unit Dense layer.""" + # You can also use the Flatten layer if you want to test a model without any + # weights. + layer = tf.layers.Dense(1, use_bias=True) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + def loss_fn(): + y = tf.reshape(logits, []) - tf.constant(1.) + return y * y + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss_fn()) + + assert mode == tf.estimator.ModeKeys.TRAIN + + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss_fn(), global_step=global_step) + return tf.estimator.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op) + + return model_fn + + +def main(_): + distribution = tf.contrib.distribute.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + config = tf.estimator.RunConfig(train_distribute=distribution) + + def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + labels = tf.data.Dataset.from_tensors([1.]).repeat(10) + return tf.data.Dataset.zip((features, labels)) + + estimator = tf.estimator.Estimator( + model_fn=build_model_fn_optimizer(), config=config) + estimator.train(input_fn=input_fn, steps=10) + + eval_result = estimator.evaluate(input_fn=input_fn) + print("Eval result: {}".format(eval_result)) + + def predict_input_fn(): + predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + return predict_features + + predictions = estimator.predict(input_fn=predict_input_fn) + # TODO(anjalsridhar): This returns a generator object, figure out how to get + # meaningful results here. + print("Prediction results: {}".format(predictions)) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py new file mode 100644 index 0000000000000000000000000000000000000000..b87224251ca3844fc81c6f32a893d2c71664a955 --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An example tf.keras model that is trained using MirroredStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from sys import argv +import numpy as np +import tensorflow as tf + + +def input_fn(): + x = np.random.random((1024, 10)) + y = np.random.randint(2, size=(1024, 1)) + x = tf.cast(x, tf.float32) + dataset = tf.data.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(10) + dataset = dataset.batch(32) + return dataset + + +def main(args): + if len(args) < 2: + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example./') + return + + print('Using %s to store checkpoints.' % args[1]) + + strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) + optimizer = tf.train.GradientDescentOptimizer(0.2) + + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) + model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + + model.compile(loss='binary_crossentropy', optimizer=optimizer) + model.summary() + tf.keras.backend.set_learning_phase(True) + keras_estimator = tf.keras.estimator.model_to_estimator( + keras_model=model, config=config, model_dir=args[1]) + + keras_estimator.train(input_fn=input_fn, steps=10) + eval_result = keras_estimator.evaluate(input_fn=input_fn) + print('Eval result: {}'.format(eval_result)) + +if __name__ == '__main__': + tf.app.run(argv=argv) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa90df79bbcd621fe7b7d0da04256b7a59d5bfe --- /dev/null +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -0,0 +1,279 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.ops.losses import losses_impl + + +class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers() + + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph", "eager"]))) + def testOptimizerInsideModelFn(self, distribution, optimizer_fn): + created_variables = [] + trainable_variables = [] + + def appending_creator(next_creator, *args, **kwargs): + v = next_creator(*args, **kwargs) + created_variables.append(v.name) + if "trainable" in kwargs and kwargs["trainable"]: + trainable_variables.append(v.name) + return v + + # Creator scope needs to be set before it's used inside + # `distribution.scope`. + with variable_scope.variable_creator_scope( + appending_creator), distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=True, + create_optimizer_inside_model_fn=True) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + def get_expected_variables(optimizer_fn, num_parameter_devices): + variables_map = { + "GradientDescent": ["dense/kernel", "dense/bias"], + "Adam": [ + "dense/kernel", "dense/bias", "beta1_power", "beta2_power", + "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", + "dense/bias/Adam_1" + ] + } + variables = variables_map[optimizer_fn().get_name()] + variables.extend([ + v + "/replica_{}".format(replica) + for v in variables + for replica in range(1, num_parameter_devices) + ]) + return set([v + ":0" for v in variables]) + + self.assertEqual( + get_expected_variables(optimizer_fn, + len(distribution.parameter_devices)), + set(created_variables)) + + @combinations.generate( + combinations.times(combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + momentum=[0.8, 0.9, 0.99], + renorm=[False, True]))) + def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, + renorm): + """Verifies that moving mean updates are reduced across towers.""" + with distribution.scope(): + num_towers = len(distribution.worker_devices) + model_fn, dataset, batchnorm = batchnorm_example( + optimizer_fn, + batch_per_epoch=num_towers, + momentum=momentum, + renorm=renorm) + + # Disable prefetching since that makes the specific input on each device + # to be non deterministic, and this test relies on specific input being + # on each device. + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + distribution._prefetch_on_device = False + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group( + distribution.unwrap( + distribution.call_for_each_tower( + model_fn, + iterator.get_next(), + run_concurrently=batchnorm.built)) + + ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + expected_moving_means = [0.] * 8 + + def averaged_batch_mean(i): + # Each batch has shape [16, 8] where the ith element in jth list is + # (8 * j + i + tower_id * 100). So the batch mean in each tower is + # (60 + i + tower_id * 100). So here comes its batch mean over all + # towers: + return 60. + i + (num_towers - 1.) / 2. * 100. + + for _ in range(10): + run_step() + moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + + # We make sure that the moving_mean is updated as if the sample mean is + # calculated over all towers. + for i, expected_moving_mean in enumerate(expected_moving_means): + expected_moving_means[i] -= (( + expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) + self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + + @combinations.generate( + combinations.times( + combinations.combine( + distribution=[combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn], + loss_reduction=[losses_impl.Reduction.SUM, + losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, + use_callable_loss): + with distribution.scope(): + all_vars = [] + + def model_fn(x, y): + + def loss_fn(): + # Use fixed initialization to make the steps deterministic. + w = variable_scope.get_variable("w", initializer=[[2.]]) + all_vars.append(w) + predict = math_ops.matmul(x, w) + return losses_impl.mean_squared_error( + y, predict, reduction=loss_reduction) + + optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + dataset = dataset_ops.Dataset.zip((features, labels)).repeat() + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, *iterator.get_next(), run_concurrently=False)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + self.assertEqual(distribution.num_towers, len(all_vars)) + v = all_vars[0] + self.assertTrue(all([v is vi for vi in all_vars[1:]])) + weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + # Our model is: + # predict = x * w + # loss = (predict - y)^2 + # dloss/dpredict = 2*(predict - y) + # dloss/dw = 2 * x^T @ (predict - y) + # For our batch size of 2, assuming sum loss reduction: + # x = [2, 7] + # y = [6, 21] + # w_initial = 2 + # predict = [4, 14] + # predict - y = [-2, -7] + # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106 + # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 + # with sum loss reduction, or 10.6 with mean. + if loss_reduction == losses_impl.Reduction.SUM: + # Note that the "distribution.num_towers" factor will go away once + # we split the input across towers, instead of pulling a complete + # batch of input per tower. + self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001) + else: + # One of the mean loss reductions. + self.assertNear(weight, 2 + 10.6, 0.0001) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0edb3a11df7788991ca14f957494d87593a449 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -0,0 +1,497 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class MirroredStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import six + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.contrib.distribute.python import values +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import coordinator +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +def _cpu_device(device): + cpu_device = tf_device.DeviceSpec.from_string(device) + cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) + return cpu_device.to_string() + + +class _RequestedStop(Exception): + pass + + +class MirroredStrategy(distribute_lib.DistributionStrategy): + """Mirrors vars to distribute across multiple devices on a single machine. + + This strategy uses one tower per device and sync replication. + """ + + def __init__(self, + devices=None, + num_gpus=None, + cross_tower_ops=None, + prefetch_on_device=None): + super(MirroredStrategy, self).__init__() + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + + assert devices, "Must specify at least one device." + assert len(set(devices)) == len(devices), ( + "No duplicates allowed in `devices` argument.") + # TODO(josh11b): Require at least 2 devices? + self._devices = devices + self._canonical_device_set = set( + [device_util.canonicalize(d) for d in devices]) + self._device_index = values.PerDevice( + dict((d, i) for i, d in enumerate(devices))) + self._cross_tower_ops = cross_tower_ops + self._prefetch_on_device = prefetch_on_device + + def _create_variable(self, next_creator, *args, **kwargs): + """Create a mirrored variable. See `DistributionStrategy.scope`.""" + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + + tower_local = kwargs.pop("tower_local_reduce_method", None) + if tower_local is not None: + kwargs["trainable"] = False + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = {} + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + kwargs["name"] = "%s/replica_%d" % (var0name, i) + # Initialize replicas with the same value: + if context.executing_eagerly(): + initial_value = index[devices[0]].value() + else: + initial_value = index[devices[0]].initial_value + kwargs["initial_value"] = array_ops.identity(initial_value) + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + assert not isinstance(v, values.DistributedVariable) + index[d] = v + + if tower_local is None: + result = values.MirroredVariable(index, index[devices[0]]) + else: + result = values.TowerLocalVariable( + index, index[devices[0]], tower_local) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + def distribute_dataset(self, dataset): + per_device_dataset = values.PerDeviceDataset( + dataset, self._devices, self._prefetch_on_device) + return per_device_dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + # TODO(josh11b): In eager mode, use one thread per device, or async mode. + return self._get_cross_tower_ops().broadcast(tensor, destinations or + self._devices) + + def _call_for_each_tower(self, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times for when called for different devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator( + clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(self._devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( + self, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup( + {t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + def map(self, map_over, fn, *args, **kwargs): + # TODO(josh11b): In eager mode, use one thread per device. + index = {} + i = 0 + for m in map_over: + d = self._devices[i % len(self._devices)] + with ops.device(d): + l = index.get(d, []) + l.append(fn(m, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + index[d] = l + # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput + # in addition to PerDevice data. + return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + + def configure(self, session_config=None): + if self._cross_tower_ops is None: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) + + def _get_cross_tower_ops(self): + if self._cross_tower_ops is None: + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + return self._cross_tower_ops + + def _reduce(self, method_string, value, destinations): + if len(self._devices) == 1 and not isinstance(value, values.PerDevice): + value = values.PerDevice({self._devices[0]: value}) + assert isinstance(value, values.PerDevice) + + return self._get_cross_tower_ops().reduce( + method_string, value, destinations=destinations) + + def _batch_reduce(self, method_string, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(method_string, + value_destination_pairs) + + def _update(self, var, fn, *args, **kwargs): + # TODO(josh11b): Also support TowerLocalVariables here? If so, args and + # kwargs don't need to be mirrored. + assert isinstance(var, values.MirroredVariable) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d, v in var._index.items(): # pylint: disable=protected-access + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + assert isinstance(colocate_with, list) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d in colocate_with: + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(*values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + assert isinstance(destination, six.string_types) + if isinstance(val, values.TowerLocalVariable): + val = self.reduce(val.reduce_method, val, destinations=destination) + with ops.device(destination): + return fn(self.unwrap(val)[0]) + + assert isinstance(val, values.Mirrored), ( + "val = %s (type %s)" % (val, val.__class__.__name__)) + if val.on_device(destination): + with ops.device(destination): + # Use an identity here to make sure we are returning a tensor + # instead of e.g. a variable object. + return array_ops.identity(fn(val.get(destination))) + device = None + for d in self._devices: + if val.on_device(d): + device = d + break + assert device is not None, ( + "Could not find destination %s in list of devices %s." % + (destination, val.devices)) + with ops.device(device): + v = fn(val.get(device)) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_device_set: + return [val.get(device=d) for d in self._devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + @property + def is_single_tower(self): + return len(self._devices) == 1 + + @property + def num_towers(self): + return len(self._devices) + + def _worker_device_index(self): + return self._device_index + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._devices) + + @property + def parameter_devices(self): + return list(self._devices) + + def non_slot_devices(self, var_list): + del var_list + return list(self._devices) + + def _get_devices_from(self, colocate_with=None): + if colocate_with is None: + return self._devices + elif isinstance(colocate_with, values.DistributedValues): + # pylint: disable=protected-access + return list(colocate_with._index.keys()) + elif isinstance(colocate_with, six.string_types): + return [colocate_with] + else: + return colocate_with + + class _MirroredTowerThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, + **kwargs): + super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access + self.coord = coord + self.distribution = dist + self.device = device + self.tower_id = dist.worker_devices.index(device) + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # TowerContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_tower_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). + self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + # pylint: disable=protected-access + self.context_mode = context.context()._eager_context.mode + if not context.context()._context_handle: + context.context()._initialize_handle_and_devices() + self.context_device_policy = ( + pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + context.context()._context_handle)) + self.graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] + self._captured_var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._captured_name_scope = self.graph.get_name_scope() + if self._captured_name_scope: + self._captured_name_scope += "/" + if self.tower_id > 0: + if not self._captured_name_scope: + self._captured_name_scope = "" + self._captured_name_scope += "tower_%d/" % self.tower_id + + def run(self): + # pylint: disable=protected-access + self.graph._variable_creator_stack = self._variable_creator_stack + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + with self.coord.stop_on_exception(), \ + context.context()._mode(self.context_mode), \ + context.context().device_policy(self.context_device_policy), \ + self.graph.as_default(), \ + MirroredTowerContext(self.distribution, self.tower_id), \ + ops.device(self.device), \ + ops.name_scope(self._captured_name_scope), \ + variable_scope.variable_scope( + self._captured_var_scope, reuse=self.tower_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + +class MirroredTowerContext(distribute_lib.TowerContext): + """TowerContext used in MirroredStrategy.call_for_each_tower(). + + Opened in `_MirroredTowerThread`, to allow the user to invoke + `MirroredStrategy`'s specific implementation of `merge_call()`, + which works by delegating the function and its arguments to + the main thread (the one that invoked + `MirroredStrategy.call_for_each_tower()`). + """ + + def _merge_call(self, fn, *args, **kwargs): + """Delegate to the main thread to actually perform merge_call().""" + t = threading.current_thread() # a _MirroredTowerThread + t.merge_fn = fn + t.merge_args = args + t.merge_kwargs = kwargs + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def device(self): + distribute_lib.require_tower_context(self) + return self._distribution_strategy.worker_devices[self._tower_id] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9f06da8e2ed185c2c32f79a5a4f5407165fb1d --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -0,0 +1,435 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multi-GPU tests for MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib + +GPU_TEST = "test_gpu" in sys.argv[0] + + +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + if context.num_gpus() > 1: + devices = ["/device:GPU:0", "/device:GPU:1"] + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + return mirrored_strategy.MirroredStrategy(devices) + + def testMinimizeLossEager(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + soft_placement = not GPU_TEST + print("testMinimizeLossGraph soft_placement:", soft_placement) + self._test_minimize_loss_graph( + self._get_distribution_strategy(), soft_placement=soft_placement) + + def testMapReduce(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_tower_id(self._get_distribution_strategy()) + + def testNumTowers(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self.assertEqual(2, self._get_distribution_strategy().num_towers) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testRunRegroupError(self): + + def run_fn(device_id): + # Generates a list with different lengths on different devices. + # Will fail in _regroup() (if more than one device). + return list(range(device_id)) + + dist = self._get_distribution_strategy() + with dist.scope(), self.assertRaises(AssertionError): + dist.call_for_each_tower(run_fn, dist.worker_device_index) + + @test_util.run_in_graph_and_eager_modes() + def testReduceToCpu(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return device_id + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + expected = sum(range(len(dist.worker_devices))) + self.assertEqual(expected, self.evaluate(unwrapped[0])) + + +@test_util.with_c_api +class MirroredStrategyVariableCreationTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSingleVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + # This variable should be created only once across the threads because of + # special variable_creator functions used by `dist.call_for_each_tower`. + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnnamedVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v = variable_scope.variable(1.0) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # Default name of "Variable" will be used. + self.assertEquals("Variable:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + for i in range(5): + vs.append(variable_scope.variable(1.0, name="foo" + str(i))) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for i, v in enumerate(result): + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals("foo" + str(i) + ":0", v.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariablesWithSameCanonicalName(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + vs.append(variable_scope.variable(1.0, name="foo/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) + vs.append(variable_scope.variable(1.0, name="foo/bar_1")) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for v in result: + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals(4, len(result)) + self.assertEquals("foo/bar:0", result[0].name) + self.assertEquals("foo_1/bar:0", result[1].name) + self.assertEquals("foo_1/bar_1:0", result[2].name) + self.assertEquals("foo/bar_1:0", result[3].name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableWithSameCanonicalNameAcrossThreads(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(device_id): + v = variable_scope.variable(1.0, name="foo_" + str(device_id)) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # The resulting mirrored variable will use the name from the first device. + self.assertEquals("foo_0:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithLayers(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(features): + with variable_scope.variable_scope("common"): + layer1 = core.Dense(1) + layer1(features) + layer2 = core.Dense(1) + layer2(features) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + layer3 = core.Dense(1) + layer3(features) + return [(layer1.kernel, layer1.bias), + (layer2.kernel, layer2.bias), + (layer3.kernel, layer3.bias)] + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + features = dist.distribute_dataset(features).get_next() + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, features, run_concurrently=False) + suffixes = ["", "_1", "_2"] + for (kernel, bias), suffix in zip(result, suffixes): + self.assertIsInstance(kernel, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertIsInstance(bias, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithGetVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.get_variable("var-thread0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var-thread1", [1]) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.get_variable("var-thread2", [1]) + + return v0, v1, v2 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with variable_scope.variable_scope("main"): + v = variable_scope.get_variable("var-main0", [1]) + self.assertEquals("main/var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(3, len(result)) + v0, v1, v2 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("main/var-thread0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("main/common/var-thread1:0", v1.name) + self.assertIsInstance(v2, values.MirroredVariable) + self.assertEquals("main/common/var-thread2:0", v2.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testThreeDevices(self): + self._skip_eager_if_gpus_less_than(2) + + def model_fn(): + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNonMatchingVariableCreation(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(name): + v = variable_scope.variable(1.0, name=name) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + names = values.DistributedValues({ + "/device:CPU:0": "foo", + "/device:GPU:0": "bar" + }) + with self.assertRaises(RuntimeError): + _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariable(self): + self._skip_eager_if_gpus_less_than(1) + + all_v_sum = {} + all_v_mean = {} + + def model_fn(device_id): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + with tower_context.tower_local_var_scope("mean"): + v_mean = variable_scope.variable(4.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) + updates = [v_sum.assign_add(2.0 + device_id), + v_mean.assign(6.0 * device_id)] + all_v_sum[device_id] = v_sum + all_v_mean[device_id] = v_mean + return updates, v_sum, v_mean + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + # Create "sum" and "mean" versions of TowerLocalVariables. + ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + # Should see the same wrapping instance in all towers. + self.assertIs(all_v_sum[0], ret_v_sum) + self.assertIs(all_v_mean[0], ret_v_mean) + for i in range(1, dist.num_towers): + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Apply updates + self.evaluate(variables.global_variables_initializer()) + self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + expected_sum = 0.0 + expected_mean = 0.0 + for i, d in enumerate(dist.worker_devices): + # Test access within a device scope, should see different values. + with ops.device(d): + v_sum_value = self.evaluate(ret_v_sum.read_value()) + v_mean_value = self.evaluate(ret_v_mean.read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected + + # fetch() should return the value you get by applying the + # reduction across all towers. + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + expected_mean /= len(dist.worker_devices) + self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + + # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not + # testing this in eager mode. + + def testNameScope(self): + def model_fn(): + with ops.name_scope("foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(1.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("main/foo/" + name + ":0", v0.name) + self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name) + + def testWithDefaultName(self): + def model_fn(): + with ops.name_scope(None, "foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(2.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("foo/" + name + ":0", v0.name) + self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ef0ecc77a8e8432dfa4eb6da7c324b371dab70 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -0,0 +1,91 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribute as distribute_lib + + +@test_util.with_c_api +class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +@test_util.with_c_api +class VariableCreatorStackTest(test.TestCase): + + def testCreatorStacksAreThreadLocal(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + + def model_fn(device_id): + assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + str(device_id) + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + dist.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = dist.call_for_each_tower(model_fn, dist.worker_device_index) + result = dist.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEquals(expected, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..7644acedc99361d7287a91832d76bc68cbc6ac0a --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Monitor is responsible for training, checkpointing and recovery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.ops import variables + + +class Monitor(object): + """Executes training steps, recovers and checkpoints. + + Note that this class is particularly preliminary, experimental, and + expected to change. + """ + # TODO(isaprykin): Support step functions that need multiple session calls. + # TODO(isaprykin): Support extra arguments to the step function. + # TODO(isaprykin): Support recovery, checkpointing and summaries. + + def __init__(self, step_callable, session=None): + """Initialize the Monitor with components for executing training steps. + + Args: + step_callable: a training `Step` that's capable of signaling when done. + session: a `Session` instance that's needed for graph mode. + + Raises: + ValueError: if `session` was provided for eager mode or not provided for + graph mode. + """ + if context.executing_eagerly(): + if session is not None: + raise ValueError("Should not provide a `session` in Eager mode.") + self._run_step = step_callable + else: + if session is None: + raise ValueError("Should provide a `session` in Graph mode.") + self._run_step = session.make_callable(step_callable()) + session.run(variables.global_variables_initializer()) + + def run_steps(self, num_steps=None): + step = 0 + while num_steps is None or step < num_steps: + try: + self._run_step() + step += 1 + except errors.OutOfRangeError: + break diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8277e1e7919e86ef616b31d0986589dcc9c49bbd --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Monitor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import monitor as monitor_lib +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.training import gradient_descent + + +class MonitorTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example(optimizer_fn, distribution) + + if context.executing_eagerly(): + monitor = monitor_lib.Monitor(single_loss_step, None) + else: + with self.test_session() as sess: + monitor = monitor_lib.Monitor(single_loss_step, sess) + + monitor.run_steps(1) + + self.assertEqual(1, len(layer.trainable_variables)) + mirrored_weight_variable = layer.trainable_variables[0] + start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = abs(numpy.array(start_error) - 1) + + monitor.run_steps(9) + end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = abs(numpy.array(end_error) - 1) + self.assertGreaterEqual(start_error, end_error) + + def testPassingASessionInEager(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with self.test_session() as sess: + with self.assertRaisesRegexp(ValueError, "Should not provide"): + _ = monitor_lib.Monitor(step_function, sess) + + def testNotPassingASessionInGraph(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with context.graph_mode(), ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, "Should provide"): + _ = monitor_lib.Monitor(step_function, session=None) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..39c49442b9c3245cfd0b67a51be68773a6fd3ff4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class OneDeviceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +class OneDeviceStrategy(distribute_lib.DistributionStrategy): + """A distribution strategy for running on a single device.""" + # TODO(josh11b): Do we wrap values in types to generate errors if you are + # doing something that won't work with other DistributionStrategy + # implementations? + + def __init__(self, device): + super(OneDeviceStrategy, self).__init__() + self._device = device + + def _create_variable(self, next_creator, *args, **kwargs): + # No need to distinguish tower-local variables when not mirroring, + # we just enforce that they are not trainable. + if kwargs.pop("tower_local_reduce_method", None) is not None: + kwargs["trainable"] = False + + colocate_with = kwargs.pop("colocate_with", None) + if colocate_with is None: + with ops.device(self._device): + return next_creator(*args, **kwargs) + if isinstance(colocate_with, six.string_types): + with ops.device(colocate_with): + return next_creator(*args, **kwargs) + if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + isinstance(colocate_with[0], six.string_types)): + with ops.device(colocate_with[0]): + return next_creator(*args, **kwargs) + with ops.colocate_with(colocate_with): + return next_creator(*args, **kwargs) + + def distribute_dataset(self, dataset): + if context.executing_eagerly(): + return datasets.Iterator(dataset) + else: + return dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + return tensor + + def _call_for_each_tower(self, fn, *args, **kwargs): + # We don't run `fn` in multiple threads in OneDeviceStrategy. + kwargs.pop("run_concurrently", None) + with ops.device(self._device), _OneDeviceTowerContext(self): + return fn(*args, **kwargs) + + def map(self, map_over, fn, *args, **kwargs): + with ops.device(self._device): + return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) + + def _reduce(self, method_string, value, destinations): + if not isinstance(value, values.MapOutput): + return value + l = value.get() + assert l + with ops.device(self._device): + if method_string == "sum": + return math_ops.add_n(l) + elif method_string == "mean": + return math_ops.add_n(l) / len(l) + else: + assert False + + def _update(self, var, fn, *args, **kwargs): + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(var, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + del colocate_with + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(*args, **kwargs) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + with ops.device(self._device): + v = fn(val) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, value): + return [value] + + @property + def is_single_tower(self): + return True + + @property + def num_towers(self): + return 1 + + @property + def worker_devices(self): + return [self._device] + + @property + def parameter_devices(self): + return [self._device] + + def non_slot_devices(self, var_list): + del var_list + return [self._device] + + def _worker_device_index(self): + return 0 + + +class _OneDeviceTowerContext(distribute_lib.TowerContext): + + def __init__(self, distribution_strategy): + distribute_lib.TowerContext.__init__( + self, distribution_strategy, tower_id=0) + + @property + def device(self): + return self._distribution_strategy.worker_devices[0] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7101ed0756f44b846f10ddc6d429afe005a2f196 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class OneDeviceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +@test_util.with_c_api +class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0912b625f44342d22acc0ce9bb52a6b632c75a0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables + + +class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group(distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built))) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ddf3cece1c3fa549d6d2999a9bff9671fcdd76 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,166 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._devices = devices + input_iterator = input_dataset.make_one_shot_iterator() + input_iterator_handle = input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, input_iterator.output_types, input_iterator.output_shapes, + input_iterator.output_classes) + return remote_iterator.get_next() + + target_device = gen_dataset_ops.iterator_get_device( + input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size) + self._buffering_resources.append(buffer_resource_handle) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator(self._input_dataset, self._devices, + self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + raise NotImplementedError("`prefetch_to_devices()` is not currently " + "compatible with initializable iterators. Use " + "`make_one_shot_iterator()` instead.") + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed16f4607881f2864479c04b4c25e95d9fa1850 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + with self.test_session() as sess: + for _ in range(5): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(10)), set(output)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator.py b/tensorflow/contrib/distribute/python/shared_variable_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..a7083e279f20803b227dcd52f6420ae832aa2df4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to re-use variables created on first device on subsequent devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +_VARIABLE_UNIQUIFYING_REGEX = re.compile(r"_\d/") +_VARIABLE_UNIQUIFYING_REGEX_AT_END = re.compile(r"_\d$") + + +def _canonicalize_variable_name(name): + # If no name is specified, uses default name "Variable". + if name is None: + return "Variable" + # Replace all instances of "_/" with "/" + name = _VARIABLE_UNIQUIFYING_REGEX.sub("/", name) + # Replace any instances of "_" at the end of the string with "" + name = _VARIABLE_UNIQUIFYING_REGEX_AT_END.sub("", name) + return name + + +def make_fn(shared_variable_store, device_id): + """Construct the variable creator function for device `device_id`. + + Constructs custom variable creator functions for the given device. + On first device (device_id == 0), it creates the variable using the + `next_creator`, and stores it in the provided `shared_variable_store`. + On all other devices (device_id > 0), it tries to re-use the variable + already created with the same name. If no such variable exists, it throws an + error. + Additionally, we de-uniquify variable names before checking for matches. This + helps re-use variables which are intended to be the same but have different + names due to variable uniquification happening upstream. Since this might + mean we may have multiple variables with the same canonical name, we store + them in a list per canonical name and return them in the same order as well. + + Args: + shared_variable_store: A dictionary that we will use to store variables + created on the first device, and re-used by creators for other devices. + device_id: Integer index of the device whose creator should be + constructed. + + Returns: + An appropriate creator function based on device_id. + + """ + variable_scope_access_index = {} + assert isinstance(device_id, int) + + def create_new_variable(next_creator, *args, **kwargs): + """Create the variable using `next_creator` and store it.""" + canonical_name = _canonicalize_variable_name(kwargs.get("name")) + v = next_creator(*args, **kwargs) + + if canonical_name not in shared_variable_store: + shared_variable_store[canonical_name] = [] + shared_variable_store[canonical_name].append(v) + return v + + def reuse_variable(next_creator, *args, **kwargs): + """Re-use existing variable from store with same name (in order).""" + del next_creator, args + name = kwargs.get("name") + canonical_name = _canonicalize_variable_name(name) + + try: + variable_index = variable_scope_access_index.get(canonical_name, 0) + v = shared_variable_store[canonical_name][variable_index] + # TODO(priyag): Make this variable re-use more robust by adding checks + # that the requested shape and dtype match the existing variable. + variable_scope_access_index[canonical_name] = variable_index + 1 + return v + except (KeyError, IndexError): + raise RuntimeError( + "Tried to create variable {} with mismatching name on device {}". + format(name, device_id)) + + if device_id == 0: + return create_new_variable + else: + return reuse_variable diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..713494d603b855be2863af9f24ab98d4cf048042 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SharedVariableCreator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope + + +class CanonicalizeVariableNameTest(test.TestCase): + + def _canonicalize(self, name): + return shared_variable_creator._canonicalize_variable_name(name) + + def testNoName(self): + self.assertEquals("Variable", self._canonicalize(None)) + + def testPatternInMiddle(self): + self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz")) + + def testPatternAtEnd(self): + self.assertEquals("foo", self._canonicalize("foo_1")) + + def testWrongPatterns(self): + self.assertEquals("foo_1:0", self._canonicalize("foo_1:0")) + self.assertEquals("foo1", self._canonicalize("foo1")) + self.assertEquals("foo_a", self._canonicalize("foo_a")) + + +@test_util.with_c_api +class SharedVariableCreatorTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSharedVariable(self): + + shared_variable_store = {} + num_devices = 3 + creator_fns = [] + for i in range(num_devices): + creator_fn = shared_variable_creator.make_fn(shared_variable_store, i) + creator_fns.append(creator_fn) + + with variable_scope.variable_creator_scope(creator_fns[0]): + v0 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[1]): + v1 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[2]): + v2 = variable_scope.variable(1.0, name="foo") + + # v1 and v2 should be same as v0 + self.assertIs(v1, v0) + self.assertIs(v2, v0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py new file mode 100644 index 0000000000000000000000000000000000000000..cef5fd2f8943d348a0721cd72032bf6cb2199ad9 --- /dev/null +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple network to use in tests and examples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import step_fn +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.layers import core +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def single_loss_example(optimizer_fn, distribution, use_bias=False): + """Build a very simple network to use in tests and examples.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=use_bias) + + def loss_fn(x): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer, + distribution) + + # Layer is returned for inspecting the kernels in tests. + return single_loss_step, layer + + +def minimize_loss_example(optimizer_fn, + use_bias=False, + use_callable_loss=True, + create_optimizer_inside_model_fn=False): + """Example of non-distribution-aware legacy code.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # An Optimizer instance is created either outside or inside model_fn. + outer_optimizer = None + if not create_optimizer_inside_model_fn: + outer_optimizer = optimizer_fn() + + layer = core.Dense(1, use_bias=use_bias) + + def model_fn(x): + """A very simple model written by the user.""" + + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + optimizer = outer_optimizer or optimizer_fn() + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + return model_fn, dataset, layer + + +def batchnorm_example(optimizer_fn, + batch_per_epoch=1, + momentum=0.9, + renorm=False): + """Example of non-distribution-aware legacy code with batch normalization.""" + # input shape is [16, 8], input values are increasing in both dimensions. + dataset = dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + optimizer = optimizer_fn() + batchnorm = normalization.BatchNormalization( + renorm=renorm, momentum=momentum, fused=False) + + def model_fn(x): + + def loss_fn(): + y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) + loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + return loss + + # Callable loss. + return optimizer.minimize(loss_fn) + + return model_fn, dataset, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..82514c64be40b421c4a9887932f2cfb8e1ac4be0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The step function abstraction represents a single training step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.training import optimizer as optimizer_lib + + +class Step(object): + """Interface for performing each step of a training algorithm.""" + + def __init__(self, distribution): + self._distribution = distribution + + @property + def distribution(self): + return self._distribution + + def __call__(self): + """Perform one step of this training algorithm.""" + return self.step(self.inputs()) + + def inputs(self): + """For the generating the input to be passed to `step()`.""" + raise NotImplementedError("must be implemented in descendants") + + def step(self, inputs): + """Perform the main computation of this training algorithm.""" + raise NotImplementedError("must be implemented in descendants") + + +class StandardInputStep(Step): + """Step with a standard implementation of input handling. + + Args: + input_dataset: a tf.data Dataset that provides input. + """ + + def __init__(self, input_dataset, distribution): + Step.__init__(self, distribution) + self._distributed_input = distribution.distribute_dataset(input_dataset) + + def inputs(self): + return self._distributed_input.get_next() + + +class StandardSingleLossStep(StandardInputStep): + """A step function that implements a training step for a feed forward network. + + An instance of this class is intended to be used as a callable: + + ```python + ... + step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) + step.initialize(distribution) + + # Run a single training step on a given DistributionStrategy: + step(distribution) + ... + ``` + + Args: + input_dataset: a tf.data Dataset that provides input. + loss_fn: a function that returns loss. + optimizer: an optimizer that implements an update rule. + distribution: a `DistributionStrategy` object. + """ + + def __init__(self, input_dataset, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, input_dataset, distribution) + self._loss_fn = loss_fn + self._optimizer = optimizer + self._is_run_concurrently = False + + def step(self, inputs): + with self._distribution.scope(): + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, inputs, run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step sequentially, + # so that layers.build() is not executed in parallel. Otherwise, multiple + # sets of mirrored variables are going to be created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75c5ec9659d193e77d219ba79977615d58841d64 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import variables + + +class SingleLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example( + optimizer_fn, distribution, use_bias=True) + + if context.executing_eagerly(): + run_step = single_loss_step + else: + with self.test_session() as sess: + run_step = sess.make_callable(single_loss_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4ad9f146bc1d6a987fbeecbb05122946137154 --- /dev/null +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -0,0 +1,225 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for testing DistributionStrategy descendants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer + + +class _TestException(Exception): + pass + + +# May be the argument to either distribution.call_for_each_tower() or +# get_tower_context().merge_call() +def _raise_exception_fn(_=None): + raise _TestException() + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that raises an exception. +def _merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that raises an exception. +def _call_raises_fn(dist): + dist.call_for_each_tower(_raise_exception_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, +# calls a get_tower_context().merge_call() that calls a +# call_for_each_tower() that raises an exception. +def _merge_call_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_raises_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that calls a +# get_tower_context().merge_call() that raises an exception. +def _call_merge_raises_fn(dist): + dist.call_for_each_tower(_merge_raises_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that calls a call_for_each_tower() that +# calls a get_tower_context().merge_call() that raises an exception. +def _merge_call_merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + + +class DistributionTestBase(test.TestCase): + """Some tests that should work with any DistributionStrategy.""" + + def _test_minimize_loss_eager(self, d): + with d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a + # common `implicit_grad` function and put it in DistributionStrategy. + grad_fn = backprop.implicit_grad(loss) + grad_fn = optimizer.get_filtered_grad_fn(grad_fn) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + # control_dependencies irrelevant but harmless in eager execution + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + for i in range(10): + b, a = step() + if i == 0: + before, = b # pylint: disable=unbalanced-tuple-unpacking + after, = a # pylint: disable=unbalanced-tuple-unpacking + + error_before = abs(before.numpy() - 1) + error_after = abs(after.numpy() - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_minimize_loss_graph(self, d, soft_placement=False): + config = config_pb2.ConfigProto() + config.allow_soft_placement = soft_placement + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + with context.graph_mode(), \ + ops.Graph().as_default(), \ + self.test_session(config=config) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + grad_fn = backprop.implicit_grad(loss) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + before_out, after_out = step() + variables.global_variables_initializer().run() + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_map_reduce(self, d, in_graph=None): + with d.scope(): + map_in = [constant_op.constant(i) for i in range(10)] + map_out = d.map(map_in, lambda x, y: x * y, 2) + observed = d.fetch(d.reduce("sum", map_out)) + expected = 90 # 2 * (0 + 1 + ... + 9) + self.assertEqual(expected, observed.numpy()) + + def _test_device_index(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(device_id): + self.assertLess(device_id, len(d.worker_devices)) + self.assertFalse(expected_devices[device_id]) + expected_devices[device_id] = True + + d.call_for_each_tower(mark_devices_fn, d.worker_device_index) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_tower_id(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(): + tower_id = distribute_lib.get_tower_context().tower_id + self.assertLess(tower_id, len(d.worker_devices)) + self.assertFalse(expected_devices[tower_id]) + expected_devices[tower_id] = True + + d.call_for_each_tower(mark_devices_fn) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_call_and_merge_exceptions(self, dist): + with dist.scope(): + with self.assertRaises(_TestException): + dist.call_for_each_tower(_raise_exception_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_merge_raises_fn) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py new file mode 100644 index 0000000000000000000000000000000000000000..87bf0590384cc74ca0f0575bcef4e84599a8b666 --- /dev/null +++ b/tensorflow/contrib/distribute/python/values.py @@ -0,0 +1,578 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing distributed values. + +See go/tf-distribution-strategy. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import weakref + +import six + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import saver +from tensorflow.python.util import nest + + +# pylint: disable=line-too-long +# TODO(josh11b): Should device values be strings or DeviceSpec objects +# Not sure DeviceSpec objects are usable as a dict key. +class DistributedValues(object): + """Holds a map from device to values. Either PerDevice or Mirrored.""" + + def __init__(self, index): + self._index = {device_util.canonicalize(key): value + for key, value in six.iteritems(index)} + + def get(self, device=None): + """Returns the value for the current device or raises a ValueError.""" + if device is None: + tower_context = distribute_lib.get_tower_context() + if tower_context: + device = tower_context.device + else: + device = distribute_lib.get_update_device() + if device is None: + device = device_util.current() + device = device_util.canonicalize(device) + try: + return self._index[device] + except KeyError: + raise ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())) + + def on_device(self, device): + device = device_util.canonicalize(device) + return device in self._index + + @property + def devices(self): + return list(self._index.keys()) + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + # TODO(josh11b): Possibly make an accessor for _index for use by + # DistributionStrategy implementations. + + +class DistributedDelegate(DistributedValues): + """A map from device to values; acts as the same type as the values.""" + + def __init__(self, index): + super(DistributedDelegate, self).__init__(index) + + def __getattr__(self, name): + return getattr(self.get(), name) + + # pylint: disable=multiple-statements + def __add__(self, o): return self.get() + o + def __radd__(self, o): return o + self.get() + def __sub__(self, o): return self.get() - o + def __rsub__(self, o): return o - self.get() + def __mul__(self, o): return self.get() * o + def __rmul__(self, o): return o * self.get() + def __truediv__(self, o): return self.get() / o + def __rtruediv__(self, o): return o / self.get() + def __floordiv__(self, o): return self.get() // o + def __rfloordiv__(self, o): return o // self.get() + def __mod__(self, o): return self.get() % o + def __rmod__(self, o): return o % self.get() + def __lt__(self, o): return self.get() < o + def __le__(self, o): return self.get() <= o + def __gt__(self, o): return self.get() > o + def __ge__(self, o): return self.get() >= o + def __and__(self, o): return self.get() & o + def __rand__(self, o): return o & self.get() + def __or__(self, o): return self.get() | o + def __ror__(self, o): return o | self.get() + def __xor__(self, o): return self.get() ^ o + def __rxor__(self, o): return o ^ self.get() + def __getitem__(self, o): return self.get()[o] + def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo) + def __rpow__(self, o): return pow(o, self.get()) + def __invert__(self): return ~self.get() + def __neg__(self): return -self.get() + def __abs__(self): return abs(self.get()) + + def __div__(self, o): + try: + return self.get().__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self.get().__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self.get().__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self.get().__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + # TODO(josh11b): Even more operator overloads. + + +class PerDevice(DistributedValues): + """Holds a map from device to unsynchronized values.""" + pass + + +class Mirrored(DistributedValues): + """Holds a map from device to values which are kept in sync.""" + pass + + +def _assign_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign(array_ops.identity(tensor)) + + +DistributedVarOp = collections.namedtuple( + "DistributedVarOp", ["name", "graph", "type"]) + + +class DistributedVariable(DistributedDelegate): + """Holds a map from device to variables.""" + # TODO(josh11b): Support changing the set of variables if e.g. if new + # devices are joining or a device is to leave. + + def __init__(self, index): + # Child class must set self._primary_var before calling + # super(...).__init__(index). + self._common_name = self._primary_var.name.split(":")[0] + super(DistributedVariable, self).__init__(index) + + @property + def initializer(self): + return control_flow_ops.group([v.initializer for v in self._index.values()]) + + @property + def graph(self): + return self._primary_var.graph + + @property + def _shared_name(self): + return self._common_name + + @property + def _unique_id(self): + return self._primary_var._unique_id # pylint: disable=protected-access + + @property + def name(self): + return self._primary_var.name + + @property + def dtype(self): + return self._primary_var.dtype + + @property + def shape(self): + return self._primary_var.shape + + def get_shape(self): + return self._primary_var.get_shape() + + def to_proto(self, export_scope=None): + return self._primary_var.to_proto(export_scope=export_scope) + + @property + def op(self): + # We want cross-tower code that does some var.op.X calls + # to work (even if the current device isn't in self.devices), but + # other uses of var.op in a cross-tower context to fail. + if distribute_lib.get_cross_tower_context(): + return DistributedVarOp(self._primary_var.op.name, + self._primary_var.op.graph, + self._primary_var.op.type) + return self.get().op + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) +ops.register_dense_tensor_like_type(DistributedVariable) + + +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + +def _get_update_device(): + """Validate we are in update/update_non_slot() and return current device. + + This is used in MirroredVariable.assign* members, to make sure they + are only called via an update method, to make sure all components of the + variable are being updated in a consistent way. + + Returns: + A string device. + + Raises: + RuntimeError: If not in distribution.update()/.update_non_slot(). + """ + device = distribute_lib.get_update_device() + if device is None: + raise RuntimeError( + "Use DistributionStrategy.update() to modify a MirroredVariable.") + return device + + +class MirroredVariable(DistributedVariable, Mirrored, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are kept in sync.""" + + def __init__(self, index, primary_var): + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access + self._primary_var = primary_var + super(MirroredVariable, self).__init__(index) + + # We use _get_update_device() for the assign* methods to enforce + # that we are in an update() function. The arguments to update() are + # automatically unwrapped so the update() function would normally + # see regular variables, not MirroredVariables. However, the update + # function can still operate on wrapped MirroredVariables through + # object members, captured arguments, etc. This is more likely in an + # update_non_slot() function (like OptimizerV2._finish), which can + # update several non-slot variables in one call. + def assign_sub(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign(*args, **kwargs) + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): + """Class for defining how to restore a TowerLocalVariable.""" + + def __init__(self, tower_local_variable, name): + self._tower_local_variable = tower_local_variable + # We use a callable so that we don't have to evaluate this expression + # in the case where we are trying to restore instead of save. + def tensor(): + return distribute_lib.get_distribution_strategy().fetch( + tower_local_variable) + spec = saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=tower_local_variable.dtype) + super(_TowerLocalSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._tower_local_variable.reduce_method == "sum": + tensor *= 1. / len(self._tower_local_variable.devices) + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + + +class TowerLocalVariable(DistributedVariable, PerDevice, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are reduced on save.""" + + def __init__(self, index, primary_var, reduce_method): + self._primary_var = primary_var + self._reduce_method = reduce_method + super(TowerLocalVariable, self).__init__(index) + + def assign_sub(self, *args, **kwargs): + return self.get().assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get().assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get().assign(*args, **kwargs) + + @property + def reduce_method(self): + return self._reduce_method + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + TowerLocalVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _TowerLocalSaveable(self, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +def _devices_match(d1, d2): + return device_util.canonicalize(d1) == device_util.canonicalize(d2) + + +def regroup(per_device, wrap_class=PerDevice): + """Makes device->nest map into a nest of PerDevice/Mirrored values.""" + items = list(per_device.items()) + assert items + v0 = items[0][1] # First value + + if isinstance(v0, list): + for _, v in items[1:]: + assert isinstance(v, list) + assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % + (len(v), len(v0), v, v0)) + return [regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))] + + if isinstance(v0, tuple): + for _, v in items[1:]: + assert isinstance(v, tuple) + assert len(v) == len(v0) + regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))) + if hasattr(v0, "_fields"): + # This tuple is in fact a namedtuple! Create a new namedtuple instance + # and initialize it with the regrouped values: + assert hasattr(type(v0), "_make") + return type(v0)._make(regrouped_tuple) + else: + return regrouped_tuple + + if isinstance(v0, dict): + v0keys = set(v0.keys()) + for _, v in items[1:]: + assert isinstance(v, dict) + assert set(v.keys()) == v0keys + return {key: regroup({k: v[key] for k, v in items}, wrap_class) + for key in v0keys} + + # If exactly the same object across all devices, return it unwrapped. + same_id = True + for _, v in items[1:]: + if v is not v0: + same_id = False + break + # Consider three cases where same_id is true: + # * If v0 is a MirroredVariable (and same_id means it is the same + # across all devices), we want to return it. We check + # MirroredVariable specifically since it can look like it + # has a _mirrored_container member since its members do. + # * If v0 is a member of a mirrored variable, in which case + # hasattr(v0, "_mirrored_container") is true, we want to + # return the MirroredVariable that contains it using the + # _mirrored_container logic below. This case can trigger + # same_id when there is only one device. + # * In any other situation, same_id means we return v0. + if same_id and (isinstance(v0, MirroredVariable) or + not hasattr(v0, "_mirrored_container")): + return v0 + + # Detect the case where each device has a parallel component of the + # same MirroredVariable. In this case we want to return the + # containing MirroredVariable, after a bunch of sanity checking. + # In particular, each component should have the same container, + # and the devices of the variables should match the keys of the + # per-device dictionary. + # TODO(josh11b): Do we need similar logic for TowerLocalVariables? + if hasattr(v0, "_mirrored_container"): + # pylint: disable=protected-access + assert not isinstance(v0, MirroredVariable), ( + "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) + assert _devices_match(v0.device, items[0][0]), ( + "v0.device = %s, items = %s" % (v0.device, items)) + mirrored_container = v0._mirrored_container() + assert mirrored_container is not None + for d, v in items[1:]: + assert _devices_match(v.device, d), ( + "v.device = %s, d = %s, items = %s" % (v.device, d, items)) + assert mirrored_container is v._mirrored_container() + return mirrored_container + # pylint: enable=protected-access + + return wrap_class(per_device) + + +def select_device(device, structured): + """Specialize a nest of regular & per-device values for one device.""" + def _get(x): + return x.get(device) if isinstance(x, DistributedValues) else x + + return nest.map_structure(_get, structured) + + +def select_device_mirrored(device, structured): + """Specialize a nest of regular & mirrored values for one device.""" + def _get_mirrored(x): + if isinstance(x, DistributedValues): + if not isinstance(x, Mirrored): + raise TypeError( + "Expected value to be mirrored across towers: %s in %s." % + (x, structured)) + return x.get(device) + else: + return x + + return nest.map_structure(_get_mirrored, structured) + + +class PerDeviceDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" + + def __init__(self, iterator, devices, prefetch_on_device=None): + self._iterator = iterator + self._devices = devices + self._prefetch_on_device = prefetch_on_device + + def get_next(self, name=None): + """Scatter the input across devices.""" + if self._prefetch_on_device: + data_list = self._iterator.get_next(name=name) + index = dict(zip(self._devices, data_list)) + else: + batch = self._iterator.get_next(name=name) + index = {} + def get_ith(i): + return lambda x: x[i] + + for i, d in enumerate(self._devices): + index[d] = nest.map_structure(get_ith(i), batch) + if context.executing_eagerly(): + with ops.device(d): + index[d] = nest.map_structure(array_ops.identity, index[d]) + + return regroup(index) + + +class PerDeviceDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" + + def __init__(self, dataset, devices, prefetch_on_device=None): + self._devices = devices + + # Default to using prefetching in graph mode, unless specified. + # TODO(priyag): Enable prefetching in eager mode. + self._prefetch_on_device = prefetch_on_device + if self._prefetch_on_device is None: + self._prefetch_on_device = not context.executing_eagerly() + assert not (self._prefetch_on_device and context.executing_eagerly()), ( + "Prefetching is only supported in graph mode currently") + + if self._prefetch_on_device: + self._dataset = dataset + else: + # TODO(priyag): If dropping remainder is not appropriate, find another + # approach to distributing the dataset when not possible to divide evenly. + # Possibly not an issue when we start using PartitionedDataset. + self._dataset = dataset.apply( + batching.batch_and_drop_remainder(len(devices))) + + def make_one_shot_iterator(self): + """Get a one time use iterator for the distributed PerDeviceDataset.""" + if self._prefetch_on_device: + on_device_dataset = self._dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + dataset_iterator = on_device_dataset.make_one_shot_iterator() + elif context.executing_eagerly(): + dataset_iterator = datasets.Iterator(self._dataset) + else: + dataset_iterator = self._dataset.make_one_shot_iterator() + + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + +class MapOutput(object): + """Map can result in multiple outputs per device.""" + + def __init__(self, l): + self._l = l + + def get(self): + return self._l diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0d4b7d6c78b7cf63c613201d83d4793ecfe76b --- /dev/null +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -0,0 +1,807 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the distributed values library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import device_util +from tensorflow.python.training import saver as saver_lib + + +@test_util.with_c_api +class DistributedValuesTest(test.TestCase): + + def testGetEager(self): + with ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testGetGraph(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testCanonicalization(self): + canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] + v = values.DistributedValues({"": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/device:CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/cpu:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + with self.assertRaises(AssertionError): + v = values.DistributedValues({"/device:cpu:0": 42}) + + +@test_util.with_c_api +class DistributedDelegateTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testGetAttr(self): + with ops.device("/device:CPU:0"): + + class Foo(object): + + def __init__(self, x): + self.x = x + + v = values.DistributedDelegate( + {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + self.assertEqual(7, v.x) + with self.assertRaises(AttributeError): + _ = v.y + + @test_util.run_in_graph_and_eager_modes() + def testOperatorOverride(self): + with ops.device("/device:CPU:0"): + v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + # v should act like int(7). + self.assertEqual(8, v + 1) + self.assertEqual(10, 3 + v) + self.assertEqual(14, v + v) + self.assertEqual(5, v - 2) + self.assertEqual(6, 13 - v) + self.assertEqual(0, v - v) + self.assertEqual(14, v * 2) + self.assertEqual(21, 3 * v) + self.assertEqual(49, v * v) + self.assertEqual(3.5, v / 2) + self.assertEqual(1.5, 10.5 / v) + self.assertEqual(3, v // 2) + self.assertEqual(2, 15 // v) + self.assertEqual(1, v % 2) + self.assertEqual(2, 16 % v) + self.assertTrue(v < 12) + self.assertTrue(v <= 12) + self.assertFalse(v > 12) + self.assertFalse(v >= 12) + self.assertFalse(12 < v) + self.assertFalse(12 <= v) + self.assertTrue(12 > v) + self.assertTrue(12 >= v) + self.assertEqual(3, v & 3) + self.assertEqual(3, 11 & v) + self.assertEqual(15, v | 8) + self.assertEqual(23, 16 | v) + self.assertEqual(4, v ^ 3) + self.assertEqual(12, 11 ^ v) + self.assertEqual(343, pow(v, 3)) + self.assertEqual(3, pow(v, 3, 10)) + self.assertEqual(128, pow(2, v)) + self.assertEqual(-7, -v) + self.assertEqual(~7, ~v) + self.assertEqual(7, abs(v)) + with self.assertRaises(TypeError): + _ = v[2] + + +def _device_str(d): + return "/device:GPU:" + str(d) + + +def _nested_value(d): + return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) + + +def _make_mirrored(): + v = [] + index = {} + devices = ["/device:GPU:0", "/device:CPU:0"] + for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + mirrored = values.MirroredVariable(index, v[0]) + return v, devices, mirrored + + +@test_util.with_c_api +class RegroupAndSelectDeviceTest(test.TestCase): + + def _is_per_device(self, result, expected, klass=values.PerDevice): + self.assertIsInstance(result, klass) + # We canonicalize the devices to match the device strings returned + # by PerDevice, which also does device string canonicalization. + devices = [device_util.canonicalize(_device_str(i)) + for i in range(len(expected))] + self.assertEqual(set(devices), set(result.devices)) + for i, d in enumerate(devices): + self.assertEqual(expected[i], result.get(d)) + self.assertEqual(expected[i], result.get(_device_str(i))) + + def testNested(self): + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"]) + self._is_per_device(result[2], ["h1", "h2"]) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"]) + self._is_per_device(result[1][2], ["g1", "g2"]) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"]) + self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # select_device_mirrored() should fail due to non-mirrored values + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(0), result) + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(1), result) + + def testWrapClass(self): + # Normally a mirrored value would be the same across devices, but + # for a test it is convenient to be able to tell the values apart. + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}, + values.Mirrored) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # Values are marked as mirrored, so select_device_mirrored() is allowed. + self.assertEqual(_nested_value("1"), + values.select_device_mirrored(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device_mirrored(_device_str(1), result)) + + def testMirroredContainer(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + v, devices, mirrored = _make_mirrored() + result = values.regroup(dict(zip(devices, v))) + self.assertIs(mirrored, result) + + def testSameId(self): + foo = object() + result = values.regroup({_device_str(0): ("a", foo), + _device_str(1): ("b", foo)}) + self.assertIsInstance(result, tuple) + self.assertEqual(2, len(result)) + self._is_per_device(result[0], ["a", "b"]) + self.assertIs(foo, result[1]) + + # Test select_device(), should undo the merge done by regroup(). + result_0 = values.select_device(_device_str(0), result) + self.assertIsInstance(result_0, tuple) + self.assertEqual(2, len(result_0)) + self.assertEqual("a", result_0[0]) + self.assertIs(foo, result_0[1]) + result_1 = values.select_device(_device_str(1), result) + self.assertIsInstance(result_1, tuple) + self.assertEqual(2, len(result_1)) + self.assertEqual("b", result_1[0]) + self.assertIs(foo, result_1[1]) + + def testOneDevice(self): + result = values.regroup({_device_str(0): _nested_value("1")}) + # On one device regroup() and select_device() are basically identity. + self.assertEqual(_nested_value("1"), result) + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + + # The one exception has to do with MirroredVariables. + d = "/device:CPU:0" + with ops.device(d): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + index = {d: v} + mirrored = values.MirroredVariable(index, v) + result = values.regroup(index) + self.assertIs(mirrored, result) + + def testNamedTupleEstimatorSpec(self): + with context.graph_mode(), ops.Graph().as_default(): + created_estimator_specs = {} + to_regroup = {} + + for device_id in range(3): + spec = model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=constant_op.constant(device_id / 2), + train_op=array_ops.identity(constant_op.constant(device_id))) + created_estimator_specs[device_id] = spec + to_regroup[_device_str(device_id)] = spec + + merged_estimator_spec = values.regroup(to_regroup) + + self.assertTrue( + isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) + self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + for device_id in range(3): + d = _device_str(device_id) + self.assertEquals(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEquals(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) + # Scaffold is populated by `EstimatorSpec.__new__`. + self.assertEquals(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) + # Also test that we can undo the merge using select_device() + self.assertEquals(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) + + +@test_util.with_c_api +class PerDeviceDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator_no_prefetch(self, devices, dataset, expected_values): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_one_shot_iterator() + + for expected_value in expected_values: + next_element = iterator.get_next() + actual = self.evaluate([ + values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator_with_prefetch(self, devices, dataset, expected_values): + if not context.executing_eagerly(): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=True) + iterator = per_device_dataset.make_one_shot_iterator() + + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] + for expected_value in expected_values: + next_element = iterator.get_next() + combined_actual.extend(self.evaluate([ + values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator(self, devices, dataset, expected_values): + self._test_iterator_no_prefetch(devices, dataset, expected_values) + self._test_iterator_with_prefetch(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes() + def testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + +@test_util.with_c_api +class MirroredVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, _, mirrored = _make_mirrored() + + self.assertEquals(v[0].name, mirrored.name) + self.assertEquals(v[0].dtype, mirrored.dtype) + self.assertEquals(v[0].shape, mirrored.shape) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + mirrored = values.MirroredVariable(index, v) + + self.assertEquals(v.name, mirrored.name) + self.assertEquals(v.dtype, mirrored.dtype) + self.assertEquals(v.shape, mirrored.shape) + + def _assign_mirrored(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreMirroredOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path, saver = self._save_return_saver(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + + # Restores the saved value of 3. to both variables. + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + def _save_mirrored(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path = self._save(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.)) + + # Saves the current value of var, 3. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3. to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3., self.evaluate(var)) + + def _restore_mirrored(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [7., 8.]) + + # Restores the saved value of 3. to both variables. + saver = saver_lib.Saver(var_list=[mirrored]) + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_mirrored(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_mirrored(save_path) + + +_devices = ["/device:GPU:0", "/device:CPU:0"] + + +def _make_tower_local(method): + v = [] + index = {} + for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + tower_local = values.TowerLocalVariable(index, v[0], method) + return v, tower_local + + +@test_util.with_c_api +class TowerLocalVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, tower_local = _make_tower_local("sum") + + self.assertEquals(v[0].name, tower_local.name) + self.assertEquals(v[0].dtype, tower_local.dtype) + self.assertEquals(v[0].shape, tower_local.shape) + self.assertEquals("sum", tower_local.reduce_method) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + tower_local = values.TowerLocalVariable(index, v, "mean") + + self.assertEquals(v.name, tower_local.name) + self.assertEquals(v.dtype, tower_local.dtype) + self.assertEquals(v.shape, tower_local.shape) + self.assertEquals("mean", tower_local.reduce_method) + + def _assign_tower_local(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + def _dist_scope(self): + return mirrored_strategy.MirroredStrategy(_devices).scope() + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalSumOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 7. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 7. which gets divided equally + # between the variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalMeanOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 3.5 to both variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _save_tower_local_mean(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_tower_local_sum(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [1.5, 2.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.5)) + + # Saves the current value of var, 3.5. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3.5 to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3.5, self.evaluate(var)) + + def _restore_tower_local_mean(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _restore_tower_local_sum(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_tower_local_sum(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalMeanRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalSumRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_sum(save_path) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 1bd73ee7044de34988144196f53299db2fb80fcf..9799901483f1a8fa192b97b3d0f052e672c26843 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -457,6 +457,20 @@ cuda_py_test( tags = ["no_windows"], # TODO: needs investigation on Windows ) +cuda_py_test( + name = "batch_reshape_test", + size = "small", + srcs = ["python/kernel_tests/batch_reshape_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sample_stats_test", size = "medium", @@ -487,11 +501,7 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", ], - tags = [ - "manual", - "noasan", - "noguitar", - ], + shard_count = 4, ) cuda_py_test( @@ -745,18 +755,6 @@ cuda_py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # === Bijector Tests ========================================================== cuda_py_test( @@ -1106,25 +1104,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "sigmoid_centered_test", - size = "small", - srcs = ["python/kernel_tests/bijectors/sigmoid_centered_test.py"], - additional_deps = [ - ":bijectors_py", - ":distributions_py", - "//third_party/py/numpy", - "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - # Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to # avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 61c411271d0bb8d7b4cc3b14992b82ec1e5674ed..4d4489468d9dcfbe152c42f5f841f6c25a9f1e6f 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -24,6 +24,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.autoregressive import * +from tensorflow.contrib.distributions.python.ops.batch_reshape import * from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * @@ -96,9 +97,10 @@ _allowed_symbols = [ 'ReparameterizationType', 'Distribution', 'Autoregressive', - 'Binomial', + 'BatchReshape', 'Bernoulli', 'Beta', + 'Binomial', 'BetaWithSoftplusConcentration', 'Categorical', 'Chi2', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c8d2cf6e75f049248c6b16f429847889d141fa --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -0,0 +1,568 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchReshape.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class _BatchReshapeTest(object): + + def make_wishart(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale], axis=0), + old_batch_shape + [dims, dims]) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + wishart = wishart_lib.WishartFull(df=5, scale=scale_ph) + reshape_wishart = batch_reshape_lib.BatchReshape( + distribution=wishart, + batch_shape=new_batch_shape_ph, + validate_args=True) + + return wishart, reshape_wishart + + def test_matrix_variate_sample_and_log_prob(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_wishart.batch_shape_tensor() + event_shape = reshape_wishart.event_shape_tensor() + + expected_sample_shape = [3, 1] + new_batch_shape + [dims, dims] + x = wishart.sample([3, 1], seed=42) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_wishart.sample([3, 1], seed=42) + + expected_log_prob_shape = [3, 1] + new_batch_shape + expected_log_prob = array_ops.reshape( + wishart.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_wishart.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims, dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_wishart.batch_shape) + self.assertAllEqual([dims, dims], reshape_wishart.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_matrix_variate_stats(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_entropy = array_ops.reshape( + wishart.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_wishart.entropy() + + expected_mean = array_ops.reshape( + wishart.mean(), expected_matrix_stat_shape) + actual_mean = reshape_wishart.mean() + + expected_mode = array_ops.reshape( + wishart.mode(), expected_matrix_stat_shape) + actual_mode = reshape_wishart.mode() + + expected_stddev = array_ops.reshape( + wishart.stddev(), expected_matrix_stat_shape) + actual_stddev = reshape_wishart.stddev() + + expected_variance = array_ops.reshape( + wishart.variance(), expected_matrix_stat_shape) + actual_variance = reshape_wishart.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_variance.shape) + + def make_normal(self, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype(0.5 + np.arange( + np.prod(old_batch_shape)).reshape(old_batch_shape)) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + normal = normal_lib.Normal(loc=self.dtype(0), scale=scale_ph) + reshape_normal = batch_reshape_lib.BatchReshape( + distribution=normal, + batch_shape=new_batch_shape_ph, + validate_args=True) + return normal, reshape_normal + + def test_scalar_variate_sample_and_log_prob(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal( + new_batch_shape, old_batch_shape) + + batch_shape = reshape_normal.batch_shape_tensor() + event_shape = reshape_normal.event_shape_tensor() + + expected_sample_shape = new_batch_shape + x = normal.sample(seed=52) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_normal.sample(seed=52) + + expected_log_prob_shape = new_batch_shape + expected_log_prob = array_ops.reshape( + normal.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_normal.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_normal.batch_shape) + self.assertAllEqual([], reshape_normal.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_scalar_variate_stats(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal(new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + normal.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_normal.entropy() + + expected_mean = array_ops.reshape( + normal.mean(), expected_scalar_stat_shape) + actual_mean = reshape_normal.mean() + + expected_mode = array_ops.reshape( + normal.mode(), expected_scalar_stat_shape) + actual_mode = reshape_normal.mode() + + expected_stddev = array_ops.reshape( + normal.stddev(), expected_scalar_stat_shape) + actual_stddev = reshape_normal.stddev() + + expected_variance = array_ops.reshape( + normal.variance(), expected_scalar_stat_shape) + actual_variance = reshape_normal.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_variance.shape) + + def make_mvn(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + reshape_mvn = batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + return mvn, reshape_mvn + + def test_vector_variate_sample_and_log_prob(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_mvn.batch_shape_tensor() + event_shape = reshape_mvn.event_shape_tensor() + + expected_sample_shape = [3] + new_batch_shape + [dims] + x = mvn.sample(3, seed=62) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_mvn.sample(3, seed=62) + + expected_log_prob_shape = [3] + new_batch_shape + expected_log_prob = array_ops.reshape( + mvn.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_mvn.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_mvn.batch_shape) + self.assertAllEqual([dims], reshape_mvn.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_vector_variate_stats(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + mvn.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_mvn.entropy() + + expected_vector_stat_shape = new_batch_shape + [dims] + + expected_mean = array_ops.reshape( + mvn.mean(), expected_vector_stat_shape) + actual_mean = reshape_mvn.mean() + + expected_mode = array_ops.reshape( + mvn.mode(), expected_vector_stat_shape) + actual_mode = reshape_mvn.mode() + + expected_stddev = array_ops.reshape( + mvn.stddev(), expected_vector_stat_shape) + actual_stddev = reshape_mvn.stddev() + + expected_variance = array_ops.reshape( + mvn.variance(), expected_vector_stat_shape) + actual_variance = reshape_mvn.variance() + + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_covariance = array_ops.reshape( + mvn.covariance(), expected_matrix_stat_shape) + actual_covariance = reshape_mvn.covariance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + expected_covariance_, actual_covariance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + expected_covariance, actual_covariance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_covariance_, actual_covariance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_variance.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_covariance.shape) + + def test_bad_reshape_size(self): + dims = 2 + new_batch_shape = [2, 3] + old_batch_shape = [2] # 2 != 2*3 + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp( + ValueError, (r"`batch_shape` size \(6\) must match " + r"`distribution\.batch_shape` size \(2\)")): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r"`batch_shape` size must match " + r"`distributions.batch_shape` size"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_positive_shape(self): + dims = 2 + new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. + old_batch_shape = [2] + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be positive.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_vector_shape(self): + dims = 2 + new_batch_shape = 2 + old_batch_shape = [2] + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + + +class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float32 + is_static_shape = True + + +class BatchReshapeDynamicTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float64 + is_static_shape = False + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 20e754308449af3f0399101f4ea1bb47b3356424..a748acd667e58f9b527bab11d8bc4d086996e9f3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -66,12 +66,10 @@ class ChainBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): bijector = Chain([ - SoftmaxCentered( - event_ndims=1, validate_args=True), - SoftmaxCentered( - event_ndims=0, validate_args=True) + SoftmaxCentered(validate_args=True), + SoftmaxCentered(validate_args=True), ]) - x = tensor_shape.TensorShape([]) + x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 28e3e3135455348debb002b7d457e785799e1564..58ba9cedb1437df4e000ce32fe39664afa76c3b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -37,8 +37,7 @@ class InvertBijectorTest(test.TestCase): bijectors.Exp(event_ndims=1), bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]), bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), + bijectors.SoftmaxCentered(), ]: rev = bijectors.Invert(fwd) self.assertEqual("_".join(["invert", fwd.name]), rev.name) @@ -61,9 +60,9 @@ class InvertBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) + bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True)) x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) + y = tensor_shape.TensorShape([1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( y.as_list(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py index ad11d9f2484c4b08c67c5f82aec1320475d1d983..074b5f275d107fa49de42df262476bd4aa48ffae 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -69,7 +69,7 @@ class KumaraswamyBijectorTest(test.TestCase): bijector = Kumaraswamy( concentration1=concentration1, concentration0=concentration0, validate_args=True) - # Omitting the endpoints 0 and 1, since idlj will be inifinity at these + # Omitting the endpoints 0 and 1, since idlj will be infinity at these # endpoints. y = np.linspace(.01, 0.99, num=10).astype(np.float32) x = 1 - (1 - y ** concentration1) ** concentration0 diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py deleted file mode 100644 index 4ff3f334ccb59f1c117b3d35032d9e799cfd79bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import SigmoidCentered -from tensorflow.python.platform import test - - -class SigmoidCenteredBijectorTest(test.TestCase): - """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation.""" - - def testBijector(self): - with self.test_session(): - sigmoid = SigmoidCentered() - self.assertEqual("sigmoid_centered", sigmoid.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] - self.assertAllClose(y, sigmoid.forward(x).eval()) - self.assertAllClose(x, sigmoid.inverse(y).eval()) - self.assertAllClose( - -np.sum(np.log(y), axis=2), - sigmoid.inverse_log_det_jacobian(y).eval(), - atol=0., - rtol=1e-7) - self.assertAllClose( - -sigmoid.inverse_log_det_jacobian(y).eval(), - sigmoid.forward_log_det_jacobian(x).eval(), - atol=0., - rtol=1e-7) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 4a7679daad6f6acc632eb9133078499dda89e43d..cad4dd1ac8de0da6405aacb9047714b37eec73e3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -34,34 +34,9 @@ rng = np.random.RandomState(42) class SoftmaxCenteredBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation.""" - def testBijectorScalar(self): - with self.test_session(): - softmax = SoftmaxCentered() # scalar by default - self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] - self.assertAllClose(y, softmax.forward(x).eval()) - self.assertAllClose(x, softmax.inverse(y).eval()) - self.assertAllClose( - -np.sum(np.log(y), axis=2), - softmax.inverse_log_det_jacobian(y).eval(), - atol=0., - rtol=1e-7) - self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), - atol=0., - rtol=1e-7) - def testBijectorVector(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) x = np.log([[2., 3, 4], [4., 8, 12]]) y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] @@ -80,7 +55,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): def testBijectorUnknownShape(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) real_x = np.log([[2., 3, 4], [4., 8, 12]]) @@ -106,24 +81,21 @@ class SoftmaxCenteredBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): - for x, y, b in ((tensor_shape.TensorShape([]), - tensor_shape.TensorShape([2]), - SoftmaxCentered( - event_ndims=0, validate_args=True)), - (tensor_shape.TensorShape([4]), - tensor_shape.TensorShape([5]), - SoftmaxCentered( - event_ndims=1, validate_args=True))): - self.assertAllEqual(y, b.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - b.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, b.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - b.inverse_event_shape_tensor(y.as_list()).eval()) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([5]) + bijector = SoftmaxCentered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + bijector.forward_event_shape_tensor( + x.as_list()).eval()) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + bijector.inverse_event_shape_tensor( + y.as_list()).eval()) def testBijectiveAndFinite(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32) # Make y values on the simplex with a wide range. y_0 = np.ones(5).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 507ceb35853ebe0a996d789b3bdf8a5f2284549c..68e0d9cb8277f3953039963fec0da499db7a16d1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -ds = distributions +tfd = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - ds.Normal, - ds.Bernoulli, - ds.Beta, - ds.Chi2, - ds.Exponential, - ds.Gamma, - ds.InverseGamma, - ds.Laplace, - ds.StudentT, - ds.Uniform, + tfd.Normal, + tfd.Bernoulli, + tfd.Beta, + tfd.Chi2, + tfd.Exponential, + tfd.Gamma, + tfd.InverseGamma, + tfd.Laplace, + tfd.StudentT, + tfd.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -63,15 +65,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], - validate_args=True) + wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -84,19 +86,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = ds.Normal(mu, sigma, validate_args=True) + normal = tfd.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch())) - normal = ds.Normal([mu], [sigma], validate_args=True) + normal = tfd.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch())) @@ -126,7 +128,7 @@ class DistributionTest(test.TestCase): self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) def _GetFakeDistribution(self): - class FakeDistribution(ds.Distribution): + class FakeDistribution(tfd.Distribution): """Fake Distribution for testing _set_sample_static_shape.""" def __init__(self, batch_shape=None, event_shape=None): @@ -188,6 +190,105 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testStrWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("tf.distributions.Normal(" + "\"Normal\", " + "batch_shape=(), " + "event_shape=(), " + "dtype=float16)"), # Got the dtype right. + str(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("tf.distributions.Chi2(" + "\"silly\", " # What a silly name that is! + "batch_shape=(2,), " + "event_shape=(), " + "dtype=float32)"), + str(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("tf.distributions.Exponential(\"Exponential\", " + # No batch shape. + "event_shape=(), " + "dtype=float32)"), + str(exp)) + + def testStrWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN\", " + "batch_shape=(2,), " + "event_shape=(2,), " + "dtype=float64)"), + str(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN2\", " + "batch_shape=(?,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + + def testReprWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + (""), # Got the dtype right. + repr(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + (""), + repr(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("" + " event_shape=()" + " dtype=float32>"), + repr(exp)) + + def testReprWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + (""), + repr(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + (""), + repr(mvn_dynamic)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 4186cf129dbf31724c84133734da3f226817c71a..ea04e8c29a2c94d4939bad277afa380401067ff2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test @@ -455,6 +456,16 @@ class PercentileTestWithNearestInterpolation(test.TestCase): with self.assertRaisesOpError("rank"): pct.eval(feed_dict={q_ph: [0.5]}) + def test_finds_max_of_long_array(self): + # d - 1 == d in float32 and d = 3e7. + # So this test only passes if we use double for the percentile indices. + # If float is used, it fails with InvalidArgumentError about an index out of + # bounds. + x = math_ops.linspace(0., 3e7, num=int(3e7)) + with self.test_session(): + minval = sample_stats.percentile(x, q=0, validate_args=True) + self.assertAllEqual(0, minval.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py index 3548ac18078a0b40f117c2bf9e2b34d20cee163b..0400c80c29cf0c36090168b7a1a6358ad49fde49 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -22,39 +22,75 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import statistical_testing as st from tensorflow.python.framework import errors -from tensorflow.python.ops import check_ops from tensorflow.python.platform import test class StatisticalTestingTest(test.TestCase): def test_dkwm_design_mean_one_sample_soundness(self): - numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] - with self.test_session() as sess: - for ff in rates: - for fp in rates: - sufficient_n = st.min_num_samples_for_dkwm_mean_test( - numbers, 0., 1., false_fail_rate=ff, false_pass_rate=fp) - detectable_d = st.min_discrepancy_of_true_means_detectable_by_dkwm( - sufficient_n, 0., 1., false_fail_rate=ff, false_pass_rate=fp) - sess.run(check_ops.assert_less_equal(detectable_d, numbers)) + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + sufficient_n = st.min_num_samples_for_dkwm_mean_test( + thresholds, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + sufficient_n, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) def test_dkwm_design_mean_two_sample_soundness(self): - numbers = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] - with self.test_session() as sess: - for ff in rates: - for fp in rates: - (sufficient_n1, - sufficient_n2) = st.min_num_samples_for_dkwm_mean_two_sample_test( - numbers, 0., 1., 0., 1., - false_fail_rate=ff, false_pass_rate=fp) - d_fn = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample - detectable_d = d_fn( - sufficient_n1, 0., 1., sufficient_n2, 0., 1., - false_fail_rate=ff, false_pass_rate=fp) - sess.run(check_ops.assert_less_equal(detectable_d, numbers)) + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + [ + sufficient_n1, + sufficient_n2 + ] = st.min_num_samples_for_dkwm_mean_two_sample_test( + thresholds, low1=0., high1=1., low2=0., high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1=sufficient_n1, + low1=0., + high1=1., + n2=sufficient_n2, + low2=0., + high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) def test_true_mean_confidence_interval_by_dkwm_one_sample(self): rng = np.random.RandomState(seed=0) @@ -105,16 +141,16 @@ class StatisticalTestingTest(test.TestCase): def test_dkwm_mean_two_sample_assertion(self): rng = np.random.RandomState(seed=0) - num_samples = 15000 + num_samples = 4000 - # 15000 samples is chosen to be enough to find discrepancies of - # size 0.1 or more with assurance 1e-6, as confirmed here: + # 4000 samples is chosen to be enough to find discrepancies of + # size 0.2 or more with assurance 1e-6, as confirmed here: with self.test_session() as sess: d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( num_samples, 0., 1., num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) d = sess.run(d) - self.assertLess(d, 0.1) + self.assertLess(d, 0.2) # Test that the test assertion agrees that the standard # uniform distribution has the same mean as itself. @@ -124,6 +160,15 @@ class StatisticalTestingTest(test.TestCase): sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + with self.test_session() as sess: # Test that the test assertion confirms that the mean of the # standard uniform distribution is different from the mean of beta(2, 1). beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) @@ -133,6 +178,15 @@ class StatisticalTestingTest(test.TestCase): beta_high_samples, 0., 1., false_fail_rate=1e-6)) + def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + with self.test_session() as sess: # Test that the test assertion confirms that the mean of the # standard uniform distribution is different from the mean of beta(1, 2). beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index af13553c32bdb6ef4038daa5e4bbef3251cff2f3..f0ba1ec3eb57c67c1a0edb15639e91916a4509b7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -186,12 +186,14 @@ class TransformedDistributionTest(test.TestCase): standard_normal = ds.Normal(loc=0., scale=1.) multi_logit_normal = self._cls()( distribution=standard_normal, - bijector=softmax) - x = [[-np.log(3.), 0.], - [np.log(3), np.log(5)]] + bijector=softmax, + event_shape=[1]) + x = [[[-np.log(3.)], [0.]], + [[np.log(3)], [np.log(5)]]] y = softmax.forward(x).eval() - expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) - - np.sum(np.log(y), axis=-1)) + expected_log_pdf = ( + np.squeeze(stats.norm(loc=0., scale=1.).logpdf(x)) - + np.sum(np.log(y), axis=-1)) self.assertAllClose(expected_log_pdf, multi_logit_normal.log_prob(y).eval()) self.assertAllClose( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 9044aa2850ae35f29cd48b0c5f54aa948bea0408..dcecce981f16a2d9e772d4e40062ff250725c3ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -390,6 +390,26 @@ class WishartCholeskyTest(test.TestCase): chol_scale, dtype=np.int32), validate_args=False) + def testSampleBroadcasts(self): + dims = 2 + batch_shape = [2, 3] + sample_shape = [2, 1] + scale = np.float32([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale, scale], axis=0), + batch_shape + [dims, dims]) + wishart = distributions.WishartFull(df=5, scale=scale) + x = wishart.sample(sample_shape, seed=42) + with self.test_session() as sess: + x_ = sess.run(x) + expected_shape = sample_shape + batch_shape + [dims, dims] + self.assertAllEqual(expected_shape, x.shape) + self.assertAllEqual(expected_shape, x_.shape) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 852298bf334666db003353d5fc8e172ffb738668..69f3d57ff000d6c9acc8aa9e3d0ad8d9cbb6bb3c 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -36,7 +36,8 @@ class Autoregressive(distribution_lib.Distribution): "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." [(Papamakarios et + al., 2016)][1] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -45,17 +46,18 @@ class Autoregressive(distribution_lib.Distribution): Practically speaking the autoregressive property means that there exists a permutation of the event coordinates such that each coordinate is a - diffeomorphic function of only preceding coordinates. [2] + diffeomorphic function of only preceding coordinates + [(van den Oord et al., 2016)][2]. #### Mathematical Details - The probability function is, + The probability function is ```none prob(x; fn, n) = fn(x).prob(x) ``` - And a sample is generated by, + And a sample is generated by ```none x = fn(...fn(fn(x0).sample()).sample()).sample() @@ -93,13 +95,15 @@ class Autoregressive(distribution_lib.Distribution): ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "Conditional Image Generation with PixelCNN Decoders." - Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex - Graves, Koray Kavukcuoglu. Arxiv, 2016. + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + + [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, + Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with + PixelCNN Decoders. In _Neural Information Processing Systems_, 2016. https://arxiv.org/abs/1606.05328 """ diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6c35e0d6076113839481678abd3c20f8fb5db9 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -0,0 +1,415 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The BatchReshape distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import distribution as distribution_lib + + +__all__ = [ + "BatchReshape", +] + + +class BatchReshape(distribution_lib.Distribution): + """The Batch-Reshaping distribution. + + This "meta-distribution" reshapes the batch dimensions of another + distribution. + + Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support + `-1` for flattening. + + #### Examples + + ```python + tfd = tf.contrib.distributions + + dtype = np.float32 + dims = 2 + new_batch_shape = [1, 2, 3] + old_batch_shape = [6] + + scale = np.ones(old_batch_shape + [dims], dtype) + mvn = tfd.MultivariateNormalDiag(scale_diag=scale) + reshape_mvn = tfd.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape, + validate_args=True) + + reshape_mvn.batch_shape + # ==> [1, 2, 3] + + x = reshape_mvn.sample(sample_shape=[4, 5]) + x.shape + # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims] + + reshape_mvn.log_prob(x).shape + # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape + ``` + + """ + + def __init__(self, + distribution, + batch_shape, + validate_args=False, + allow_nan_stats=True, + name=None): + """Construct BatchReshape distribution. + + Args: + distribution: The base distribution instance to reshape. Typically an + instance of `Distribution`. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing the + new shape of the batch dimensions. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: The name to give Ops created by the initializer. + Default value: `"BatchReshape" + distribution.name`. + + Raises: + ValueError: if `batch_shape` is not a vector. + ValueError: if `batch_shape` has non-positive elements. + ValueError: if `batch_shape` size is not the same as a + `distribution.batch_shape` size. + """ + parameters = locals() + name = name or "BatchReshape" + distribution.name + self._distribution = distribution + with ops.name_scope(name, values=[batch_shape]) as name: + self._batch_shape_ = ops.convert_to_tensor( + batch_shape, + dtype=dtypes.int32, + name="batch_shape") + self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) + if self._batch_shape_static is not None: + self._batch_shape_static = np.int32(self._batch_shape_static) + self._runtime_assertions = validate_init_args( + self._distribution, + self._batch_shape_, + validate_args, + self._batch_shape_static) + super(BatchReshape, self).__init__( + dtype=self._distribution.dtype, + reparameterization_type=self._distribution.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=( + [self._batch_shape_] + + self._distribution._graph_parents), # pylint: disable=protected-access + name=name) + + @property + def distribution(self): + return self._distribution + + def _batch_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self._batch_shape_) + + def _batch_shape(self): + return tensor_shape.TensorShape(self._batch_shape_static) + + def _event_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self.distribution.event_shape_tensor()) + + def _event_shape(self): + return self.distribution.event_shape + + def _sample_n(self, n, seed=None): + with ops.control_dependencies(self._runtime_assertions): + x = self.distribution.sample(sample_shape=n, seed=seed) + new_shape = array_ops.concat([ + [n], + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + return array_ops.reshape(x, new_shape) + + def _log_prob(self, x): + return self._call_reshape_input_output( + self.distribution.log_prob, x) + + def _prob(self, x): + return self._call_reshape_input_output( + self.distribution.prob, x) + + def _log_cdf(self, x): + return self._call_reshape_input_output( + self.distribution.log_cdf, x) + + def _cdf(self, x): + return self._call_reshape_input_output( + self.distribution.cdf, x) + + def _log_survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.log_survival_function, x) + + def _survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.survival_function, x) + + def _entropy(self): + return self._call_and_reshape_output( + self.distribution.entropy, + [], + [tensor_shape.scalar()]) + + def _mean(self): + return self._call_and_reshape_output(self.distribution.mean) + + def _mode(self): + return self._call_and_reshape_output(self.distribution.mode) + + def _stddev(self): + return self._call_and_reshape_output(self.distribution.stddev) + + def _variance(self): + return self._call_and_reshape_output(self.distribution.variance) + + def _covariance(self): + return self._call_and_reshape_output( + self.distribution.covariance, + [self.event_shape_tensor()]*2, + [self.event_shape]*2) + + def _sample_shape(self, x): + """Computes graph and static `sample_shape`.""" + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + sample_ndims = x_ndims - batch_ndims - event_ndims + if isinstance(sample_ndims, int): + static_sample_shape = x.shape[:sample_ndims] + else: + static_sample_shape = tensor_shape.TensorShape(None) + if static_sample_shape.is_fully_defined(): + sample_shape = np.int32(static_sample_shape.as_list()) + else: + sample_shape = array_ops.shape(x)[:sample_ndims] + return sample_shape, static_sample_shape + + def _call_reshape_input_output(self, fn, x): + """Calls `fn`, appropriately reshaping its input `x` and output.""" + with ops.control_dependencies( + self._runtime_assertions + self._validate_sample_arg(x)): + sample_shape, static_sample_shape = self._sample_shape(x) + old_shape = array_ops.concat([ + sample_shape, + self.distribution.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + result = fn(array_ops.reshape(x, old_shape)) + new_shape = array_ops.concat([ + sample_shape, + self.batch_shape_tensor(), + ], axis=0) + result = array_ops.reshape(result, new_shape) + if (static_sample_shape.ndims is not None and + self.batch_shape.ndims is not None): + new_shape = static_sample_shape.concatenate(self.batch_shape) + result.set_shape(result.shape.merge_with(new_shape)) + return result + + def _call_and_reshape_output( + self, + fn, + event_shape_list=None, + static_event_shape_list=None): + """Calls `fn` and appropriately reshapes its output.""" + with ops.control_dependencies(self._runtime_assertions): + if event_shape_list is None: + event_shape_list = [self._event_shape_tensor()] + if static_event_shape_list is None: + static_event_shape_list = [self.event_shape] + new_shape = array_ops.concat( + [self.batch_shape_tensor()] + event_shape_list, + axis=0) + result = array_ops.reshape(fn(), new_shape) + if (self.batch_shape.ndims is not None and + self.event_shape.ndims is not None): + event_shape = tensor_shape.TensorShape([]) + for rss in static_event_shape_list: + event_shape = event_shape.concatenate(rss) + static_shape = result.shape.merge_with( + self.batch_shape.concatenate(event_shape)) + result.set_shape(static_shape) + return result + + def _validate_sample_arg(self, x): + """Helper which validates sample arg, e.g., input to `log_prob`.""" + with ops.name_scope(name="validate_sample_arg", values=[x]): + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + expected_batch_event_ndims = batch_ndims + event_ndims + + if (isinstance(x_ndims, int) and + isinstance(expected_batch_event_ndims, int)): + if x_ndims < expected_batch_event_ndims: + raise NotImplementedError( + "Broadcasting is not supported; too few event dims " + "(expected at least {}, saw {}).".format( + expected_batch_event_ndims, x_ndims)) + ndims_assertion = [] + elif self.validate_args: + ndims_assertion = [ + check_ops.assert_greater_equal( + x_ndims, + expected_batch_event_ndims, + message="Broadcasting is not supported; too few event dims.", + name="assert_batch_and_event_ndims_large_enough"), + ] + + if (self.batch_shape.is_fully_defined() and + self.event_shape.is_fully_defined()): + expected_batch_event_shape = np.int32(self.batch_shape.concatenate( + self.event_shape).as_list()) + else: + expected_batch_event_shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + + sample_ndims = x_ndims - expected_batch_event_ndims + if isinstance(sample_ndims, int): + sample_ndims = max(sample_ndims, 0) + if (isinstance(sample_ndims, int) and + x.shape[sample_ndims:].is_fully_defined()): + actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) + else: + sample_ndims = math_ops.maximum(sample_ndims, 0) + actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] + + if (isinstance(expected_batch_event_shape, np.ndarray) and + isinstance(actual_batch_event_shape, np.ndarray)): + if any(expected_batch_event_shape != actual_batch_event_shape): + raise NotImplementedError("Broadcasting is not supported; " + "unexpected batch and event shape " + "(expected {}, saw {}).".format( + expected_batch_event_shape, + actual_batch_event_shape)) + # We need to set the final runtime-assertions to `ndims_assertion` since + # its possible this assertion was created. We could add a condition to + # only do so if `self.validate_args == True`, however this is redundant + # as `ndims_assertion` already encodes this information. + runtime_assertions = ndims_assertion + elif self.validate_args: + # We need to make the `ndims_assertion` a control dep because otherwise + # TF itself might raise an exception owing to this assertion being + # ill-defined, ie, one cannot even compare different rank Tensors. + with ops.control_dependencies(ndims_assertion): + shape_assertion = check_ops.assert_equal( + expected_batch_event_shape, + actual_batch_event_shape, + message=("Broadcasting is not supported; " + "unexpected batch and event shape."), + name="assert_batch_and_event_shape_same") + runtime_assertions = [shape_assertion] + else: + runtime_assertions = [] + + return runtime_assertions + + +def validate_init_args( + distribution, + batch_shape, + validate_args, + batch_shape_static): + """Helper to __init__ which makes or raises assertions.""" + with ops.name_scope(name="validate_init_args", + values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access + runtime_assertions = [] + + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format( + batch_shape.shape.ndims)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_rank( + batch_shape, + 1, + message="`batch_shape` must be a vector.", + name="assert_batch_shape_is_vector"), + ] + + batch_size_static = np.prod(batch_shape_static) + dist_batch_size_static = ( + None if not distribution.batch_shape.is_fully_defined() + else np.prod(distribution.batch_shape).value) + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, + dist_batch_size_static)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_equal( + math_ops.reduce_prod(batch_shape), + math_ops.reduce_prod(distribution.batch_shape_tensor()), + message=("`batch_shape` size must match " + "`distributions.batch_shape` size."), + name="assert_batch_size"), + ] + + if batch_shape_static is not None: + if np.any(batch_shape_static < 1): + raise ValueError("`batch_shape` elements must be positive " + "(i.e., larger than zero).") + elif validate_args: + runtime_assertions += [ + check_ops.assert_positive( + batch_shape, + message=("`batch_shape` elements must be positive " + "(i.e., larger than zero)."), + name="assert_batch_shape_positive") + ] + + return runtime_assertions diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 452f1caa30fdbf5442274cbcc7f3549081b80ae9..bc6b02542ebf3b83d58f888509dafb86351de8a7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -35,7 +35,6 @@ @@RealNVP @@Reshape @@Sigmoid -@@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus @@ -72,7 +71,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.power_transform impor from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 7fe73ada4466d38a7d352f23a55d6b90ed38c84a..bef7bbb49b715497695f7513e19ecab4fa56c47e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -62,7 +62,7 @@ class Affine(bijector.Bijector): matrices, i.e., the matmul is [matrix-free]( https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - Examples: + #### Examples ```python # Y = X diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index be72ff3081225b9f9fdb6541322b7fc3d4aaa41e..33fdd32d7a0a01685690e598c69adca2c95972e9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -76,15 +76,16 @@ def _undo_batch_normalization(x, class BatchNormalization(bijector.Bijector): """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. - Applies Batch Normalization [1] to samples from a data distribution. This can - be used to stabilize training of normalizing flows [2, 3]. + Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a + data distribution. This can be used to stabilize training of normalizing + flows ([Papamakarios et al., 2016][3]; [Dinh et al., 2017][2]) When training Deep Neural Networks (DNNs), it is common practice to normalize or whiten features by shifting them to have zero mean and scaling them to have unit variance. - The `inverse()` method of the BatchNorm bijector, which is used in the - log-likelihood computation of data samples, implements the normalization + The `inverse()` method of the `BatchNormalization` bijector, which is used in + the log-likelihood computation of data samples, implements the normalization procedure (shift-and-scale) using the mean and standard deviation of the current minibatch. @@ -92,7 +93,6 @@ class BatchNormalization(bijector.Bijector): `X*std(Y) + mean(Y)` with the running-average mean and standard deviation computed at training-time. De-normalization is useful for sampling. - ```python dist = tfd.TransformedDistribution( @@ -112,19 +112,20 @@ class BatchNormalization(bijector.Bijector): `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when `training=False` but may be different when `training=True`. - [1]: "Batch Normalization: Accelerating Deep Network Training by Reducing - Internal Covariate Shift." - Sergey Ioffe, Christian Szegedy. Arxiv. 2015. - https://arxiv.org/abs/1502.03167 + #### References - [2]: "Density Estimation using Real NVP." - Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. - https://arxiv.org/abs/1605.08803 + [1]: Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating + Deep Network Training by Reducing Internal Covariate Shift. In + _International Conference on Machine Learning_, 2015. + https://arxiv.org/abs/1502.03167 - [3]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + [2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 43208ff088b469b70ebc08757daac277d4432b37..8f09e16058b766c788ab3acced6940fd0026b521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -57,7 +57,7 @@ class CholeskyOuterProduct(bijector.Bijector): that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - Examples: + #### Examples ```python bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]]) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 5251dbcb5748f75688aa43ce6e4e9dbd76be78bb..84b2340c75514c3d2c12bf4d775ba74450a0dc26 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -45,14 +45,15 @@ __all__ = [ class MaskedAutoregressiveFlow(bijector_lib.Bijector): """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, + The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a + relatively simple framework for user-specified (deep) architectures to learn + a distribution over vector-valued events. Regarding terminology, "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." + [(Papamakarios et al., 2016)][3] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -75,26 +76,26 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): Given a `shift_and_log_scale_fn`, the forward and inverse transformations are (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. + must compute each `shift` (aka `loc` or "mu" in [Germain et al. (2015)][1]) + and `log(scale)` (aka "alpha" in [Germain et al. (2015)][1]) such that each + are broadcastable with the arguments to `forward` and `inverse`, i.e., such + that the calculations in `forward`, `inverse` [below] are possible. For convenience, `masked_autoregressive_default_template` is offered as a possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. + architecture [(Germain et al., 2015)][1]. MADE is a feed-forward network that + computes a `shift` and `log(scale)` using `masked_dense` layers in a deep + neural network. Weights are masked to ensure the autoregressive property. It + is possible that this architecture is suboptimal for your task. To build + alternative networks, either change the arguments to + `masked_autoregressive_default_template`, use the `masked_dense` function to + roll-out your own, or use some other architecture, e.g., using `tf.layers`. Warning: no attempt is made to validate that the `shift_and_log_scale_fn` enforces the "autoregressive property". Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, + semantics, the forward transformation is ```python def forward(x): @@ -106,7 +107,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y ``` - and the inverse transformation is, + and the inverse transformation is ```python def inverse(y): @@ -121,7 +122,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this also proves the transform is bijective.) - #### Example Use + #### Examples ```python tfd = tf.contrib.distributions @@ -142,7 +143,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): maf.log_prob(x) # Almost free; uses Bijector caching. maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - # [1] also describes an "Inverse Autoregressive Flow", e.g., + # [Papamakarios et al. (2016)][3] also describe an Inverse Autoregressive + # Flow [(Kingma et al., 2016)][2]: iaf = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( @@ -168,14 +170,20 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): event_shape=[dims]) ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 + [2]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya + Sutskever, and Max Welling. Improving Variational Inference with Inverse + Autoregressive Flow. In _Neural Information Processing Systems_, 2016. + https://arxiv.org/abs/1606.04934 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -329,11 +337,7 @@ def masked_dense(inputs, **kwargs): """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + See [Germain et al. (2015)][1] for detailed explanation. Arguments: inputs: Tensor input. @@ -358,6 +362,12 @@ def masked_dense(inputs, Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ # TODO(b/67594795): Better support of dynamic shape. input_depth = inputs.shape.with_rank_at_least(1)[-1].value @@ -398,23 +408,24 @@ def masked_autoregressive_default_template( name=None, *args, **kwargs): - """Build the MADE Model [1]. + """Build the Masked Autoregressive Density Estimator (Germain et al., 2015). This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. + created once. It takes the input and returns the `loc` ("mu" in [Germain et + al. (2015)][1]) and `log_scale` ("alpha" in [Germain et al. (2015)][1]) from + the MADE network. Warning: This function uses `masked_dense` to create randomly initialized `tf.Variables`. It is presumed that these will be fit, just as you would any other neural architecture which uses `tf.layers.dense`. - #### About Hidden Layers: + #### About Hidden Layers Each element of `hidden_layers` should be greater than the `input_depth` (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the neural network). This is necessary to ensure the autoregressivity property. - #### About Clipping: + #### About Clipping This function also optionally clips the `log_scale` (but possibly not its gradient). This is useful because if `log_scale` is too small/large it might @@ -427,11 +438,7 @@ def masked_autoregressive_default_template( `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual `grad[clip(x)] exp(clip(x))`. - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]. shift_only: Python `bool` indicating if only the `shift` term shall be @@ -450,12 +457,20 @@ def masked_autoregressive_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms (the "mu" in + [Germain et al. (2015)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in + [Germain et al. (2015)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ with ops.name_scope(name, "masked_autoregressive_default_template", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 2840f52e742eac5e9e37a576bf7f6d6f05a07a35..71ab369d01aafc33854a2c2437f96bbb493cc6fb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -38,7 +38,7 @@ class RealNVP(bijector_lib.Bijector): """RealNVP "affine coupling layer" for vector-valued events. Real NVP models a normalizing flow on a `D`-dimensional distribution via a - single `D-d`-dimensional conditional distribution [1]: + single `D-d`-dimensional conditional distribution [(Dinh et al., 2017)][1]: `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])` `y[0:d] = x[0:d]` @@ -51,31 +51,34 @@ class RealNVP(bijector_lib.Bijector): Masking is currently only supported for base distributions with `event_ndims=1`. For more sophisticated masking schemes like checkerboard or - channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired - masked units into the first `d` units. For base distributions with - `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape. - - Recall that the MAF bijector [2] implements a normalizing flow via an - autoregressive transformation. MAF and IAF have opposite computational - tradeoffs - MAF can train all units in parallel but must sample units - sequentially, while IAF must train units sequentially but can sample in - parallel. In contrast, Real NVP can compute both forward and inverse - computations in parallel. However, the lack of an autoregressive + channel-wise masking [(Papamakarios et al., 2016)[4], use the `tfb.Permute` + bijector to re-order desired masked units into the first `d` units. For base + distributions with `event_ndims > 1`, use the `tfb.Reshape` bijector to + flatten the event shape. + + Recall that the MAF bijector [(Papamakarios et al., 2016)][4] implements a + normalizing flow via an autoregressive transformation. MAF and IAF have + opposite computational tradeoffs - MAF can train all units in parallel but + must sample units sequentially, while IAF must train units sequentially but + can sample in parallel. In contrast, Real NVP can compute both forward and + inverse computations in parallel. However, the lack of an autoregressive transformations makes it less expressive on a per-bijector basis. A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or - "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable - with the arguments to `forward` and `inverse`, i.e., such that the - calculations in `forward`, `inverse` [below] are possible. For convenience, + "mu" in [Papamakarios et al. (2016)][4]) and `log(scale)` (aka "alpha" in + [Papamakarios et al. (2016)][4]) such that each are broadcastable with the + arguments to `forward` and `inverse`, i.e., such that the calculations in + `forward`, `inverse` [below] are possible. For convenience, `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn` function. - NICE [3] is a special case of the Real NVP bijector which discards the scale - transformation, resulting in a constant-time inverse-log-determinant-Jacobian. - To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should - return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in - the `RealNVP` constructor. Calling `real_nvp_default_template` with - `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + NICE [(Dinh et al., 2014)][2] is a special case of the Real NVP bijector + which discards the scale transformation, resulting in a constant-time + inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real + NVP, `shift_and_log_scale_fn` should return `(shift, None)`, and + `is_constant_jacobian` should be set to `True` in the `RealNVP` constructor. + Calling `real_nvp_default_template` with `shift_only=True` returns one such + NICE-compatible `shift_and_log_scale_fn`. Caching: the scalar input depth `D` of the base distribution is not known at construction time. The first call to any of `forward(x)`, `inverse(x)`, @@ -103,23 +106,24 @@ class RealNVP(bijector_lib.Bijector): nvp.log_prob(0.) ``` - For more examples, see [4]. + For more examples, see [Jang (2018)][3]. - [1]: "Density Estimation using Real NVP." - Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. - https://arxiv.org/abs/1605.08803 + #### References - [2]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + [1]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 - [3]: "NICE: Non-linear Independent Components Estimation." - Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. - https://arxiv.org/abs/1410.8516 + [2]: Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear + Independent Components Estimation. _arXiv preprint arXiv:1410.8516_, + 2014. https://arxiv.org/abs/1410.8516 - [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." - Eric Jang. Blog post. January 2018. - http://blog.evjang.com/2018/01/nf2.html + [3]: Eric Jang. Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows. + _Technical Report_, 2018. http://blog.evjang.com/2018/01/nf2.html + + [4]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -250,12 +254,20 @@ def real_nvp_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms ("mu" in + [Papamakarios et al. (2016)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms ("alpha" in + [Papamakarios et al. (2016)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ with ops.name_scope(name, "real_nvp_default_template"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py deleted file mode 100644 index 223bc9d042c69be05b0e578835a31ed6e83c0c97..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SigmoidCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered - - -__all__ = [ - "SigmoidCentered", -] - - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 24add40445c60db533aac6d0c8eb537774895c65..dc94fd0a38de29f5a7ee6ca826aab0ecf8712966 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -19,10 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -45,17 +42,14 @@ class SoftmaxCentered(bijector.Bijector): e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last coordinate. - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - Example Use: ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + bijector.SoftmaxCentered().forward(tf.log([2, 3, 4])) # Result: [0.2, 0.3, 0.4, 0.1] # Extra result: 0.1 - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1]) # Result: tf.log([2, 3, 4]) # Extra coordinate removed. ``` @@ -67,82 +61,47 @@ class SoftmaxCentered(bijector.Bijector): """ def __init__(self, - event_ndims=0, validate_args=False, name="softmax_centered"): self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args, name=name) def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: + if input_shape.ndims is None or input_shape[-1] is None: return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + return tensor_shape.TensorShape([input_shape[-1] + 1]) def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 + return (input_shape[-1] + 1)[..., array_ops.newaxis] def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: + if output_shape.ndims is None or output_shape[-1] is None: return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1] - 1]) def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] if self.validate_args: # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1] - 1)[..., array_ops.newaxis] def _forward(self, x): # Pad the last dim with a zeros vector. We need this because it lets us # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - y = distribution_util.pad(y, axis=-1, back=True) + y = distribution_util.pad(x, axis=-1, back=True) # Set shape hints. if x.shape.ndims is not None: - shape = x.shape.as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) + shape = x.shape[:-1].concatenate(x.shape[-1] + 1) y.shape.assert_is_compatible_with(shape) y.set_shape(shape) @@ -167,17 +126,9 @@ class SoftmaxCentered(bijector.Bijector): log_normalization = (-x[..., -1])[..., array_ops.newaxis] x = x[..., :-1] + log_normalization - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=-1) - # Set shape hints. if y.shape.ndims is not None: - shape = y.shape.as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) + shape = y.shape[:-1].concatenate(y.shape[-1] - 1) x.shape.assert_is_compatible_with(shape) x.set_shape(shape) @@ -203,19 +154,16 @@ class SoftmaxCentered(bijector.Bijector): return -math_ops.reduce_sum(math_ops.log(y), axis=-1) def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 2831a92df8e0ad2bf681f13533cdb6f5d2089a57..1e9dbf35091fe51f2478dc085c394a77295ca4ee 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -37,7 +37,7 @@ class Square(bijector.Bijector): g is a bijection between the non-negative real numbers (R_+) and the non-negative real numbers. - Examples: + #### Examples ```python bijector.Square().forward(x=[[1., 0], [2, 1]]) diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 6b53338c4542c75d3977c075b7750c780080ac48..98edd337fe02ffbf53c6ecd9ebda9424231ea2fe 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -75,7 +75,7 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): - """Creates a _RegressionHead instance from an arbitray `Distribution`.""" + """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" def __init__(self, make_distribution_fn, diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 7dcb3e3ac4db1855adacb7ec0fa8554c45d9c859..b1bacb91b03093fa93a7e5f7eb855dc944dafb44 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -36,7 +36,7 @@ class Independent(distribution_lib.Distribution): This distribution is useful for regarding a collection of independent, non-identical distributions as a single random variable. For example, the - `Indpendent` distribution composed of a collection of `Bernoulli` + `Independent` distribution composed of a collection of `Bernoulli` distributions might define a distribution over an image (where each `Bernoulli` is a distribution over each pixel). diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 120b38db3cf72e8fce56a7e9293cdf25e75784e2..192dede6ff1d4de8d4be9965c414e7453d7b5d4b 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -44,18 +44,16 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. - Derivation from [1] and Euler's constant [2]. - [1] - - https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers - [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant - + Derivation from [here]( + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers) + and [Euler's constant]( + https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant). Args: x: input float. Returns: z: The analytic continuation of the harmonic number for the input. - """ one = array_ops.ones([], dtype=x.dtype) return math_ops.digamma(x + one) - math_ops.digamma(one) diff --git a/tensorflow/contrib/distributions/python/ops/moving_stats.py b/tensorflow/contrib/distributions/python/ops/moving_stats.py index 20f85643b9e7db61b4786dffe4115c7d3c00b046..87d40805a3c7a9c2871305af7f7182b7e2923530 100644 --- a/tensorflow/contrib/distributions/python/ops/moving_stats.py +++ b/tensorflow/contrib/distributions/python/ops/moving_stats.py @@ -47,9 +47,7 @@ def assign_moving_mean_variance( Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-1 mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Args: mean_var: `float`-like `Variable` representing the exponentially weighted @@ -72,6 +70,12 @@ def assign_moving_mean_variance( TypeError: if `mean_var` does not have float type `dtype`. TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ with ops.name_scope(name, "assign_moving_mean_variance", [variance_var, mean_var, value, decay]): @@ -183,9 +187,7 @@ def moving_mean_variance(value, decay, collections=None, name=None): Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-`1` mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Unlike `assign_moving_mean_variance`, this function handles variable creation. @@ -208,6 +210,12 @@ def moving_mean_variance(value, decay, collections=None, name=None): Raises: TypeError: if `value_var` does not have float type `dtype`. TypeError: if `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 46c2cc8b7a8c536a90176fbb2b2d52fed61e4705..e3e40b2e9ca232b9970768f21fb95887fdf0df2d 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -52,7 +52,7 @@ class OneHotCategorical(distribution.Distribution): #### Examples - Creates a 3-class distiribution, with the 2nd class, the most likely to be + Creates a 3-class distribution, with the 2nd class, the most likely to be drawn from. ```python @@ -60,7 +60,7 @@ class OneHotCategorical(distribution.Distribution): dist = OneHotCategorical(probs=p) ``` - Creates a 3-class distiribution, with the 2nd class the most likely to be + Creates a 3-class distribution, with the 2nd class the most likely to be drawn from, using logits. ```python diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index b525809015537ac8c7ee701c100fba6541fe2e92..e454a53c6275e0c60edd8c87b1c3be670f2b22de 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -35,10 +35,10 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): The RelaxedBernoulli is a distribution over the unit interval (0,1), which continuously approximates a Bernoulli. The degree of approximation is - controlled by a temperature: as the temperaturegoes to 0 the RelaxedBernoulli - becomes discrete with a distribution described by the `logits` or `probs` - parameters, as the temperature goes to infinity the RelaxedBernoulli - becomes the constant distribution that is identically 0.5. + controlled by a temperature: as the temperature goes to 0 the + RelaxedBernoulli becomes discrete with a distribution described by the + `logits` or `probs` parameters, as the temperature goes to infinity the + RelaxedBernoulli becomes the constant distribution that is identically 0.5. The RelaxedBernoulli distribution is a reparameterized continuous distribution that is the binary special case of the RelaxedOneHotCategorical diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index ff33f327c7a77597e516208cacad8c4aed65d1c9..f56ba0781604cb5a4fb3070b79aa86e09ceb6766 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -303,7 +303,7 @@ class RelaxedOneHotCategorical( The RelaxedOneHotCategorical is a distribution over random probability vectors, vectors of positive real values that sum to one, which continuously approximates a OneHotCategorical. The degree of approximation is controlled by - a temperature: as the temperaturegoes to 0 the RelaxedOneHotCategorical + a temperature: as the temperature goes to 0 the RelaxedOneHotCategorical becomes discrete with a distribution described by the `logits` or `probs` parameters, as the temperature goes to infinity the RelaxedOneHotCategorical becomes the constant distribution that is identically the constant vector of diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index dfc813361977c159d8d48f9d5b9ff03db5b4acdc..f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -301,13 +302,16 @@ def percentile(x, with ops.name_scope(name, [x, q]): x = ops.convert_to_tensor(x, name="x") - q = math_ops.to_float(q, name="q") + # Double is needed here and below, else we get the wrong index if the array + # is huge along axis. + q = math_ops.to_double(q, name="q") _get_static_ndims(q, expect_ndims=0) if validate_args: q = control_flow_ops.with_dependencies([ - check_ops.assert_rank(q, 0), check_ops.assert_greater_equal(q, 0.), - check_ops.assert_less_equal(q, 100.) + check_ops.assert_rank(q, 0), + check_ops.assert_greater_equal(q, math_ops.to_double(0.)), + check_ops.assert_less_equal(q, math_ops.to_double(100.)) ], q) if axis is None: @@ -332,7 +336,7 @@ def percentile(x, y = _move_dims_to_flat_end(x, axis, x_ndims) frac_at_q_or_above = 1. - q / 100. - d = math_ops.to_float(array_ops.shape(y)[-1]) + d = math_ops.to_double(array_ops.shape(y)[-1]) if interpolation == "lower": index = math_ops.ceil((d - 1) * frac_at_q_or_above) @@ -341,12 +345,18 @@ def percentile(x, elif interpolation == "nearest": index = math_ops.round((d - 1) * frac_at_q_or_above) + # If d is gigantic, then we would have d == d - 1, even in double... So + # let's use max/min to avoid out of bounds errors. + d = array_ops.shape(y)[-1] + # d - 1 will be distinct from d in int32. + index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1) + # Sort everything, not just the top 'k' entries, which allows multiple calls # to sort only once (under the hood) and use CSE. sorted_y = _sort_tensor(y) # result.shape = B - result = sorted_y[..., math_ops.to_int32(index)] + result = sorted_y[..., index] result.set_shape(y.get_shape()[:-1]) if keep_dims: diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 5fb6f0c7eaa8c4734ea4c161b0eee6f24d4c9850..bac0b79d5908712f4e64259768fb6f3b4558f620 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -32,45 +32,50 @@ from tensorflow.python.ops.distributions import util as distribution_util class _DistributionShape(object): """Manage and manipulate `Distribution` shape. - Terminology: - Recall that a `Tensor` has: - - `shape`: size of `Tensor` dimensions, - - `ndims`: size of `shape`; number of `Tensor` dimensions, - - `dims`: indexes into `shape`; useful for transpose, reduce. - - `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, - `batch_dims`, and `event_dims`. To understand the semantics of these - dimensions, consider when two of the three are fixed and the remaining - is varied: - - `sample_dims`: indexes independent draws from identical - parameterizations of the `Distribution`. - - `batch_dims`: indexes independent draws from non-identical - parameterizations of the `Distribution`. - - `event_dims`: indexes event coordinates from one sample. - - The `sample`, `batch`, and `event` dimensions constitute the entirety of a - `Distribution` `Tensor`'s shape. - - The dimensions are always in `sample`, `batch`, `event` order. - - Purpose: - This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into - `Distribution` notions of `sample,` `batch,` and `event` dimensions. That - is, it computes any of: + #### Terminology - ``` - sample_shape batch_shape event_shape - sample_dims batch_dims event_dims - sample_ndims batch_ndims event_ndims - ``` + Recall that a `Tensor` has: + - `shape`: size of `Tensor` dimensions, + - `ndims`: size of `shape`; number of `Tensor` dimensions, + - `dims`: indexes into `shape`; useful for transpose, reduce. + + `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, + `batch_dims`, and `event_dims`. To understand the semantics of these + dimensions, consider when two of the three are fixed and the remaining + is varied: + - `sample_dims`: indexes independent draws from identical + parameterizations of the `Distribution`. + - `batch_dims`: indexes independent draws from non-identical + parameterizations of the `Distribution`. + - `event_dims`: indexes event coordinates from one sample. + + The `sample`, `batch`, and `event` dimensions constitute the entirety of a + `Distribution` `Tensor`'s shape. + + The dimensions are always in `sample`, `batch`, `event` order. + + #### Purpose + + This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into + `Distribution` notions of `sample,` `batch,` and `event` dimensions. That + is, it computes any of: + + ``` + sample_shape batch_shape event_shape + sample_dims batch_dims event_dims + sample_ndims batch_ndims event_ndims + ``` - for a given `Tensor`, e.g., the result of - `Distribution.sample(sample_shape=...)`. + for a given `Tensor`, e.g., the result of + `Distribution.sample(sample_shape=...)`. - For a given `Tensor`, this class computes the above table using minimal - information: `batch_ndims` and `event_ndims`. + For a given `Tensor`, this class computes the above table using minimal + information: `batch_ndims` and `event_ndims`. + + #### Examples + + We show examples of distribution shape semantics. - Examples of `Distribution` `shape` semantics: - Sample dimensions: Computing summary statistics, i.e., the average is a reduction over sample dimensions. @@ -111,52 +116,54 @@ class _DistributionShape(object): tf.div(1., tf.reduce_prod(x, event_dims)) ``` - Examples using this class: - Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. - - ```python - # 150 iid samples from one multivariate Normal with two degrees of freedom. - mu = [0., 0] - sigma = [[1., 0], - [0, 1]] - mvn = MultivariateNormal(mu, sigma) - rand_mvn = mvn.sample(sample_shape=[3, 50]) - shaper = DistributionShape(batch_ndims=0, event_ndims=1) - S, B, E = shaper.get_shape(rand_mvn) - # S = [3, 50] - # B = [] - # E = [2] - - # 12 iid samples from one Wishart with 2x2 events. - sigma = [[1., 0], - [2, 1]] - wishart = Wishart(df=5, scale=sigma) - rand_wishart = wishart.sample(sample_shape=[3, 4]) - shaper = DistributionShape(batch_ndims=0, event_ndims=2) - S, B, E = shaper.get_shape(rand_wishart) - # S = [3, 4] - # B = [] - # E = [2, 2] - - # 100 iid samples from two, non-identical trivariate Normal distributions. - mu = ... # shape(2, 3) - sigma = ... # shape(2, 3, 3) - X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) - # S = [4, 25] - # B = [2] - # E = [3] - ``` - - Argument Validation: - When `validate_args=False`, checks that cannot be done during - graph construction are performed at graph execution. This may result in a - performance degradation because data must be switched from GPU to CPU. - - For example, when `validate_args=False` and `event_ndims` is a - non-constant `Tensor`, it is checked to be a non-negative integer at graph - execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` - arguments are always checked for correctness since this can be done for - "free," i.e., during graph construction. + We show examples using this class. + + Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. + + ```python + # 150 iid samples from one multivariate Normal with two degrees of freedom. + mu = [0., 0] + sigma = [[1., 0], + [0, 1]] + mvn = MultivariateNormal(mu, sigma) + rand_mvn = mvn.sample(sample_shape=[3, 50]) + shaper = DistributionShape(batch_ndims=0, event_ndims=1) + S, B, E = shaper.get_shape(rand_mvn) + # S = [3, 50] + # B = [] + # E = [2] + + # 12 iid samples from one Wishart with 2x2 events. + sigma = [[1., 0], + [2, 1]] + wishart = Wishart(df=5, scale=sigma) + rand_wishart = wishart.sample(sample_shape=[3, 4]) + shaper = DistributionShape(batch_ndims=0, event_ndims=2) + S, B, E = shaper.get_shape(rand_wishart) + # S = [3, 4] + # B = [] + # E = [2, 2] + + # 100 iid samples from two, non-identical trivariate Normal distributions. + mu = ... # shape(2, 3) + sigma = ... # shape(2, 3, 3) + X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) + # S = [4, 25] + # B = [2] + # E = [3] + ``` + + #### Argument Validation + + When `validate_args=False`, checks that cannot be done during + graph construction are performed at graph execution. This may result in a + performance degradation because data must be switched from GPU to CPU. + + For example, when `validate_args=False` and `event_ndims` is a + non-constant `Tensor`, it is checked to be a non-negative integer at graph + execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` + arguments are always checked for correctness since this can be done for + "free," i.e., during graph construction. """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 0c747f8e68529484ae6f695b8500cde74857bb11..971d65c4a69140161461fdac93bb588014dd3e88 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -181,7 +181,7 @@ def quadrature_scheme_softmaxnormal_quantiles( edges = array_ops.reshape(edges, shape=array_ops.concat([ [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) quantiles = dist.quantile(edges) - quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + quantiles = SoftmaxCentered().forward(quantiles) # Cyclically permute left by one. perm = array_ops.concat([ math_ops.range(1, 1 + batch_ndims), [0]], axis=0) @@ -248,11 +248,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of the quantiles of `p(z)` (generalized quantiles if `K > 2`). - See [1] for more details. - - [1]. "Quadrature Compound: An approximating family of distributions" - Joshua Dillon, Ian Langmore, arXiv preprints - https://arxiv.org/abs/1801.03080 + See [Dillon and Langmore (2018)][1] for more details. #### About `Vector` distributions in TensorFlow. @@ -313,6 +309,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): is_positive_definite=True), ], validate_args=True) + ``` + + #### References + + [1]: Joshua Dillon and Ian Langmore. Quadrature Compound: An approximating + family of distributions. _arXiv preprint arXiv:1801.03080_, 2018. + https://arxiv.org/abs/1801.03080 """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 8c67647a618d22a58428d78865c4ebf7d98bdf9e..887981d64ef077e2636f8031581c390f177edac8 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -66,7 +66,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): This distribution is an Affine transformation of iid [Student's t-distributions]( https://en.wikipedia.org/wiki/Student%27s_t-distribution) - and should not be confused with the [Multivate Student's t-distribution]( + and should not be confused with the [Multivariate Student's t-distribution]( https://en.wikipedia.org/wiki/Multivariate_t-distribution). The traditional Multivariate Student's t-distribution is type of [elliptical distribution]( diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index e4ac65012b9c7e3ed5ada3ed75020f3905740156..5a8c94dabf4c3c430bee544a48ee7acfe7dd7ed0 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -228,9 +228,12 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbk) # This parametrization is equivalent to Chi2, i.e., # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) + expanded_df = self.df * array_ops.ones( + self.scale_operator.batch_shape_tensor(), + dtype=self.df.dtype.base_dtype) g = random_ops.random_gamma(shape=[n], alpha=self._multi_gamma_sequence( - 0.5 * self.df, self.dimension), + 0.5 * expanded_df, self.dimension), beta=0.5, dtype=self.dtype, seed=distribution_util.gen_new_seed( diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa..9a3b780af888a597d2440b243ffb8dc98d764f18 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,12 +1,8 @@ # Eager Execution -> *WARNING*: This is a preview/pre-alpha version. The API and performance -> characteristics are subject to change. - -Eager execution is an experimental interface to TensorFlow that provides an -imperative programming style (à la [NumPy](http://www.numpy.org)). When you -enable eager execution, TensorFlow operations execute immediately; you do not -execute a pre-constructed graph with +Eager execution provides an imperative interface to TensorFlow (similiar to +[NumPy](http://www.numpy.org)). When you enable eager execution, TensorFlow +operations execute immediately; you do not execute a pre-constructed graph with [`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session). For example, consider a simple computation in TensorFlow: @@ -33,7 +29,7 @@ print(m) ## Caveats This feature is in early stages and work remains to be done in terms of smooth -support for distributed and multi-GPU training and CPU performance. +support for distributed and multi-GPU training and performance. - [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Acomp%3Aeager) - Feedback is welcome, please consider @@ -41,21 +37,23 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Eager execution is included in TensorFlow versions 1.5 and above. +Eager execution is included in TensorFlow versions 1.7 and above. Installation instructions at https://www.tensorflow.org/install/ ## Documentation For an introduction to eager execution in TensorFlow, see: -- [User Guide](python/g3doc/guide.md) +- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) ## Changelog -- 2017/10/31: Initial preview release. +- 2017/10/31: Initial preview release (in TensorFlow 1.5) - 2017/12/01: Example of dynamic neural network: [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). See [README.md](python/examples/spinn/README.md) for details. +- 2017/03: Core functionality moved out of the experimental tf.contrib namespace + in TensorFlow 1.7. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD index aedfec8924e7314addd22349c0576a84a58d9aa3..b016d2dcb504044372c895e1eedf3511751bc13e 100644 --- a/tensorflow/contrib/eager/proto/BUILD +++ b/tensorflow/contrib/eager/proto/BUILD @@ -4,17 +4,6 @@ exports_files(["LICENSE"]) load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "checkpointable_object_graph_proto", srcs = [ diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 384ef7f9630647714b77825b54b3b8a3abdfa6f3..edb9130266e4ea93d2ec6ee373a90df504da18cf 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -70,6 +70,7 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + ":checkpointable_utils", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", @@ -79,6 +80,7 @@ cuda_py_test( "//tensorflow/python/data", "//tensorflow/python/eager:test", ], + tags = ["noguitar"], ) py_library( @@ -232,12 +234,15 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", - "//tensorflow/python:io_ops", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", ], @@ -267,20 +272,8 @@ cuda_py_test( "//tensorflow/python/keras", ], tags = [ + "no_oss", # b/74395663 "no_windows", # TODO: needs investigation on Windows "notsan", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index d07121df635cc95402a4811f810007807dfa0c37..34cb8d0e0887bd5e440873bae117bf27597de11b 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -19,6 +19,7 @@ from __future__ import print_function import abc import collections +import functools import weakref from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 @@ -32,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import io_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpointable as core_checkpointable @@ -220,12 +220,16 @@ def _serialize_checkpointables( object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] - for name, saveable in ( + for name, saveable_factory in ( checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) + if callable(saveable_factory): + saveable = saveable_factory(name=attribute.checkpoint_key) + else: + saveable = saveable_factory # Figure out the name-based Saver's name for this variable. saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( [saveable], convert_variable_to_tensor=False) @@ -519,6 +523,18 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface): fetches=fetches, feed_dict=feed_dict, **kwargs) +def _copy_saver_with_new_var_list(old_saver, new_var_list): + """Copy a `tf.train.Saver`'s state to a new Saver with different variables.""" + new_saver = saver_lib.Saver(var_list=new_var_list) + # TODO(allenl): Move to copying functionality to Saver? + # pylint: disable=protected-access + new_saver._last_checkpoints = old_saver._last_checkpoints + new_saver._checkpoints_to_be_deleted = old_saver._checkpoints_to_be_deleted + new_saver._next_checkpoint_time = old_saver._next_checkpoint_time + # pylint: enable=protected-access + return new_saver + + class CheckpointableSaver(object): """Saves and restores a `Checkpointable` object and its dependencies. @@ -561,7 +577,6 @@ class CheckpointableSaver(object): self._last_save_saver = None # Op caching for restore - self._object_graph_restore_tensor = None self._last_restore_object_graph = None self._last_restore_checkpoint = None @@ -598,8 +613,7 @@ class CheckpointableSaver(object): """ named_variables, graph_proto = _serialize_object_graph( self._root_checkpointable) - in_graph_mode = not context.executing_eagerly() - if in_graph_mode: + if not context.executing_eagerly(): if session is None: session = ops.get_default_session() if self._object_graph_feed_tensor is None: @@ -618,21 +632,20 @@ class CheckpointableSaver(object): named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( tensor=object_graph_tensor, name=_OBJECT_GRAPH_PROTO_KEY) - if not in_graph_mode or self._last_save_object_graph != graph_proto: - if self._last_save_object_graph is not None and in_graph_mode: - raise NotImplementedError( - "Using a single Saver to save a mutated object graph is not " - "currently supported when graph building. Use a different Saver " - "when the object graph changes (save ops will be duplicated), or " - "file a feature request if this limitation bothers you.") - saver = saver_lib.Saver(var_list=named_variables) - if in_graph_mode: - self._last_save_saver = saver - self._last_save_object_graph = graph_proto - else: - saver = self._last_save_saver + if (self._last_save_object_graph != graph_proto + # When executing eagerly, we need to re-create SaveableObjects each time + # save() is called so they pick up new Tensors passed to their + # constructors. That means the Saver needs to be copied with a new + # var_list. + or context.executing_eagerly()): + if self._last_save_object_graph is not None: + self._last_save_saver = _copy_saver_with_new_var_list( + old_saver=self._last_save_saver, new_var_list=named_variables) + else: + self._last_save_saver = saver_lib.Saver(var_list=named_variables) + self._last_save_object_graph = graph_proto with ops.device("/cpu:0"): - save_path = saver.save( + save_path = self._last_save_saver.save( sess=_SessionWithFeedDictAdditions( session=session, feed_additions=feed_additions), save_path=file_prefix, @@ -651,7 +664,7 @@ class CheckpointableSaver(object): attribute_proto.checkpoint_key] return saver_names - def restore(self, save_path, session=None): + def restore(self, save_path): """Restore a training checkpoint. Restores `root_checkpointable` and any objects that it tracks @@ -661,8 +674,7 @@ class CheckpointableSaver(object): constructor after this call will be matched if they have a corresponding object in the checkpoint. - When building a graph, restorations are added to the graph but not run. A - session is required to retrieve checkpoint metadata. + When building a graph, restorations are added to the graph but not run. To disallow deferred loading, assert immediately that all checkpointed variables have been matched to variable objects: @@ -700,9 +712,6 @@ class CheckpointableSaver(object): object which may run initializers for objects in the dependency graph. If the checkpoint was written by the name-based `tf.train.Saver`, names are used to match variables. - session: The session to retrieve metadata with. Ignored when executing - eagerly. If not provided when graph building, the default session is - used. Returns: A load status object, which can be used to make assertions about the @@ -717,32 +726,15 @@ class CheckpointableSaver(object): return InitializationOnlyStatus(self._root_checkpointable) in_graph_mode = not context.executing_eagerly() if in_graph_mode: - if session is None: - session = ops.get_default_session() file_prefix_tensor = self._file_prefix_placeholder file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} else: - session = None with ops.device("/cpu:0"): file_prefix_tensor = constant_op.constant(save_path) file_prefix_feed_dict = None + reader = pywrap_tensorflow.NewCheckpointReader(save_path) try: - if not in_graph_mode or self._object_graph_restore_tensor is None: - with ops.device("/cpu:0"): - object_graph_string, = io_ops.restore_v2( - prefix=file_prefix_tensor, - tensor_names=[_OBJECT_GRAPH_PROTO_KEY], - shape_and_slices=[""], - dtypes=[dtypes.string], - name="object_graph_proto_read") - if in_graph_mode: - self._object_graph_restore_tensor = object_graph_string - if in_graph_mode: - object_graph_string = session.run( - self._object_graph_restore_tensor, - feed_dict=file_prefix_feed_dict) - else: - object_graph_string = object_graph_string.numpy() + object_graph_string = reader.get_tensor(_OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: # The object graph proto does not exist in this checkpoint. Try again with # name-based saving. @@ -757,7 +749,6 @@ class CheckpointableSaver(object): if in_graph_mode: dtype_map = None else: - reader = pywrap_tensorflow.NewCheckpointReader(save_path) dtype_map = reader.get_variable_to_dtype_map() checkpoint = core_checkpointable_utils._Checkpoint( # pylint: disable=protected-access object_graph_proto=object_graph_proto, @@ -877,3 +868,115 @@ class Checkpoint(core_checkpointable.Checkpointable): # initialization when executing eagerly. self._maybe_create_save_counter() return status + + +class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + """Wraps save and restore callbacks as a `SaveableObject`.""" + + def __init__(self, name, dtype, save_callback, restore_callback): + self._restore_callback = restore_callback + spec = saver_lib.BaseSaverBuilder.SaveSpec( + tensor=save_callback, + slice_spec="", + name=name, + dtype=dtype) + super(_CallbackSaveable, self).__init__( + save_callback, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return self._restore_callback(tensor) + + +class _SplitDependency(core_checkpointable.CheckpointableBase): + """Looks like a regular variable while synchronizing save/restores.""" + + def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, + fill_save_buffer_fn, consume_restore_buffer_fn): + self._save_buffer = save_buffer + self._restore_buffer = restore_buffer + self._name = name + self._dtype = dtype + self._num_components = num_components + self._fill_save_buffer_fn = fill_save_buffer_fn + self._consume_restore_buffer_fn = consume_restore_buffer_fn + + def _save(self): + """Pull from the shared buffer, populating it if necessary.""" + if self._name not in self._save_buffer: + if self._save_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be saved together.") % (self._name, self)) + self._fill_save_buffer_fn(self._save_buffer) + return self._save_buffer.pop(self._name) + + def _restore(self, tensor): + """Push into the shared buffer, flushing it if necessary.""" + if self._name in self._restore_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be restored together.") % (self._name, self)) + self._restore_buffer[self._name] = tensor + if len(self._restore_buffer) == self._num_components: + op = self._consume_restore_buffer_fn(self._restore_buffer) + self._restore_buffer.clear() + return op + else: + return control_flow_ops.no_op() + + def _gather_saveables_for_checkpoint(self): + """Looks to Checkpointable like a regular variable.""" + return { + core_checkpointable.VARIABLE_VALUE_KEY: + functools.partial(_CallbackSaveable, + dtype=self._dtype, + save_callback=self._save, + restore_callback=self._restore) + } + + +def split_dependency(component_names, component_dtypes, + fill_save_buffer_fn, consume_restore_buffer_fn): + """Creates multiple dependencies with a synchronized save/restore. + + Useful when a single op produces `Tensor`s which should each be saved under + different objects, or when `Tensor`s saved with many different objects need to + be restored together as inputs to a single op (i.e. an object which uses a + single fused op may be swapped out for a subgraph of objects, and these two + programs are checkpoint compatible). + + Args: + component_names: A sequence of names for the split + dependencies. `fill_save_buffer_fn` must add these keys to the dictionary + it is passed, and `consume_restore_buffer_fn` will receive a dictionary + with these keys. + component_dtypes: Data types for the `Tensor`s being saved and restored, a + sequence corresponding to `component_names`. + fill_save_buffer_fn: A function which takes an empty dictionary as an + argument and adds `Tensor`s with `component_names` as keys. These + `Tensor`s will be saved as if they were individual variables. + consume_restore_buffer_fn: A function which takes a dictionary with + `component_names` as keys mapping to restored individual `Tensor`s and + returns a restore op (or if executing eagerly, runs the restoration and + may return `None`). + + Returns: + A dictionary mapping from names to Checkpointable objects. If one is + reachable from an object as a dependency, the others should be too; adding + dependencies on some but not all of the objects will result in errors. + """ + save_buffer = {} + restore_buffer = {} + split_dependencies = {} + for name, dtype in zip(component_names, component_dtypes): + split_dependencies[name] = _SplitDependency( + save_buffer=save_buffer, + restore_buffer=restore_buffer, + name=name, + dtype=dtype, + num_components=len(component_names), + fill_save_buffer_fn=fill_save_buffer_fn, + consume_restore_buffer_fn=consume_restore_buffer_fn) + return split_dependencies diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index 2054878bf861553bb6cfa8d3730fa2070cf6b8bb..891c093a0f667deca6c26c453a83eca7305166a0 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -23,14 +23,18 @@ import six from tensorflow.contrib.eager.python import checkpointable_utils from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl.keras.engine import sequential from tensorflow.python.keras._impl.keras.engine import training from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops @@ -66,6 +70,87 @@ class MyModel(training.Model): return ret +def _split_variable_closure(variable): + def _fill_save_buffer_fn(save_buffer): + save_buffer["first_half"] = variable[:2] + save_buffer["second_half"] = variable[2:] + return _fill_save_buffer_fn + + +def _combine_variable_closure(variable): + def _consume_restore_buffer_fn(restore_buffer): + return variable.assign( + array_ops.concat([restore_buffer["first_half"], + restore_buffer["second_half"]], + axis=0)) + return _consume_restore_buffer_fn + + +class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): + + def __init__(self): + self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) + split_dependencies = checkpointable_utils.split_dependency( + component_names=("first_half", "second_half"), + component_dtypes=(self.combined.dtype,) * 2, + fill_save_buffer_fn=_split_variable_closure( + self.combined), + consume_restore_buffer_fn=_combine_variable_closure( + self.combined)) + for name, dep in split_dependencies.items(): + self._track_checkpointable(dep, name=name) + + +class HasRegularDeps(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class OnlyOneDep(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class SplitTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testSaveRestoreSplitDep(self): + save_checkpoint = checkpointable_utils.Checkpoint( + dep=SaveTensorSlicesAsDeps()) + self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_checkpoint.save(checkpoint_prefix) + + regular_deps = HasRegularDeps() + regular_restore_checkpoint = checkpointable_utils.Checkpoint( + dep=regular_deps) + regular_restore_checkpoint.restore( + save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) + self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) + + one_dep = OnlyOneDep() + one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + status = one_dep_restore_checkpoint.restore(save_path) + with self.assertRaises(AssertionError): + # Missing the second dependency. + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) + + restore_checkpoint = checkpointable_utils.Checkpoint() + status = restore_checkpoint.restore(save_path) + restore_checkpoint.dep = SaveTensorSlicesAsDeps() + status.assert_consumed().run_restore_ops() + self.assertAllEqual( + [1., 2., 3., 4.], + self.evaluate(restore_checkpoint.dep.combined)) + + class InterfaceTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -152,6 +237,50 @@ class InterfaceTests(test.TestCase): self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) +class _MirroringSaveable(core_saver.BaseSaverBuilder.SaveableObject): + + def __init__(self, primary_variable, mirrored_variable, name): + self._primary_variable = primary_variable + self._mirrored_variable = mirrored_variable + tensor = self._primary_variable.read_value() + spec = core_saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name) + super(_MirroringSaveable, self).__init__( + tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return control_flow_ops.group( + self._primary_variable.assign(tensor), + self._mirrored_variable.assign(tensor)) + + +class _OwnsMirroredVariables(checkpointable.CheckpointableBase): + """A Checkpointable object which returns a more complex SaveableObject.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + self.mirrored = variable_scope.get_variable( + name="mirrored", initializer=15., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + def _saveable_factory(name=self.non_dep_variable.name): + return _MirroringSaveable( + primary_variable=self.non_dep_variable, + mirrored_variable=self.mirrored, + name=name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -261,6 +390,42 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testMoreComplexSaveableReturned(self): + v = _OwnsMirroredVariables() + checkpoint = checkpointable_utils.Checkpoint(v=v) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + save_path = checkpoint.save(prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + checkpoint.restore(save_path).assert_consumed().initialize_or_restore() + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + self.evaluate(v.non_dep_variable.assign(44.)) + save_path = checkpoint.save(prefix) + self.evaluate(v.non_dep_variable.assign(45.)) + checkpoint.restore(save_path).assert_consumed().initialize_or_restore() + self.assertEqual(44., self.evaluate(v.non_dep_variable)) + self.assertEqual(44., self.evaluate(v.mirrored)) + + @test_util.run_in_graph_and_eager_modes() + def testMoreComplexSaveableReturnedWithGlobalName(self): + # The same object can also be saved using the name-based saver. + v = _OwnsMirroredVariables() + saver = core_saver.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + @test_util.run_in_graph_and_eager_modes() def testSaveRestore(self): model = MyModel() @@ -296,7 +461,11 @@ class CheckpointingTests(test.TestCase): if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly on_create_model = MyModel() - on_create_optimizer = adam.AdamOptimizer(0.001) + on_create_optimizer = adam.AdamOptimizer( + 0.001, + # Preserve beta1_power and beta2_power when appying gradients so we can + # test that they've been restored correctly. + beta1=1.0, beta2=1.0) on_create_root = checkpointable_utils.Checkpoint( optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration @@ -313,8 +482,8 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) self.assertAllEqual(optimizer_variables[2:], self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) + dummy_var = resource_variable_ops.ResourceVariable([1.]) + on_create_optimizer.minimize(loss=dummy_var.read_value) status.assert_consumed() beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) @@ -452,6 +621,35 @@ class CheckpointingTests(test.TestCase): name, = named_variables.keys() self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") + def testAnonymousVarsInInit(self): + + class Model(training.Model): + + def __init__(self): + super(Model, self).__init__() + self.w = resource_variable_ops.ResourceVariable(0.0) + self.b = resource_variable_ops.ResourceVariable(0.0) + self.vars = [self.w, self.b] + + def call(self, x): + return x * self.w + self.b + + with context.eager_mode(): + model = Model() + optimizer = adam.AdamOptimizer(learning_rate=0.05) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + checkpoint = checkpointable_utils.Checkpoint( + model=model, optimizer=optimizer) + for _ in range(2): + checkpoint.save(checkpoint_prefix) + with backprop.GradientTape() as tape: + loss = (constant_op.constant(1.) + - model(constant_op.constant(1.))) ** 2 + grad = tape.gradient(loss, model.vars) + optimizer.apply_gradients( + [(g, v) for g, v in zip(grad, model.vars)]) + @test_util.run_in_graph_and_eager_modes() def testLateDependencyTracking(self): @@ -778,6 +976,72 @@ class CheckpointingTests(test.TestCase): saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testCheckpointCleanup(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.Checkpoint(obj=obj) + for _ in range(10): + saver.save(checkpoint_prefix) + expected_filenames = ["checkpoint"] + for checkpoint_number in range(6, 11): + expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) + expected_filenames.append( + "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) + six.assertCountEqual( + self, + expected_filenames, + os.listdir(checkpoint_directory)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testCheckpointCleanupChangingVarList(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + checkpoint = checkpointable_utils.Checkpoint(obj=obj) + looped_variables = [] + for iteration in range(10): + new_variable = resource_variable_ops.ResourceVariable(iteration) + self.evaluate(new_variable.initializer) + setattr(checkpoint, "var_%d" % iteration, new_variable) + checkpoint.save(checkpoint_prefix) + looped_variables.append(new_variable) + expected_filenames = ["checkpoint"] + # We've copied the saver each time, but checkpoint management should still + # be consistent. + for checkpoint_number in range(6, 11): + expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) + expected_filenames.append( + "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) + six.assertCountEqual( + self, + expected_filenames, + os.listdir(checkpoint_directory)) + for v in looped_variables: + self.evaluate(v.assign(314)) + checkpoint.restore(checkpoint_prefix + "-6").run_restore_ops() + self.assertEqual(314, self.evaluate(checkpoint.var_9)) + self.assertEqual(314, self.evaluate(checkpoint.var_8)) + self.assertEqual(314, self.evaluate(checkpoint.var_6)) + self.assertEqual(5, self.evaluate(checkpoint.var_5)) + self.assertEqual(1, self.evaluate(checkpoint.var_1)) + self.assertEqual(0, self.evaluate(checkpoint.var_0)) + if context.executing_eagerly(): + checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + self.assertEqual(9, self.evaluate(checkpoint.var_9)) + self.assertEqual(8, self.evaluate(checkpoint.var_8)) + self.assertEqual(1, self.evaluate(checkpoint.var_1)) + self.assertEqual(0, self.evaluate(checkpoint.var_0)) + else: + # Restoring into modified graphs is an error while graph building. + with self.assertRaises(NotImplementedError): + checkpoint.restore(checkpoint_prefix + "-10").run_restore_ops() + def testManyRestoresGraph(self): """Restores after the first should not modify the graph.""" with context.graph_mode(): @@ -855,6 +1119,38 @@ class CheckpointingTests(test.TestCase): beta1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta1_power)) + @test_util.run_in_graph_and_eager_modes() + def test_sequential(self): + model = sequential.Sequential() + checkpoint = checkpointable_utils.Checkpoint(model=model) + model.add(core.Dense(4)) + second_dense = core.Dense(5) + model.add(second_dense) + model(constant_op.constant([[1.]])) + checkpoint.restore(None).initialize_or_restore() + self.evaluate(second_dense.bias.assign( + constant_op.constant([1., 2., 3., 4., 5.]))) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(second_dense.bias.assign( + constant_op.constant([5., 6., 7., 8., 9.]))) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(second_dense.bias)) + + deferred_sequential = sequential.Sequential() + deferred_sequential_checkpoint = checkpointable_utils.Checkpoint( + model=deferred_sequential) + status = deferred_sequential_checkpoint.restore(save_path) + deferred_sequential.add(core.Dense(4)) + deferred_sequential(constant_op.constant([[1.]])) + deferred_second_dense = core.Dense(5) + deferred_sequential.add(deferred_second_dense) + deferred_sequential(constant_op.constant([[1.]])) + status.run_restore_ops() + self.assertAllEqual([1., 2., 3., 4., 5.], + self.evaluate(deferred_second_dense.bias)) + class TemplateTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 332bada57b42fe53fe6be0de1b39c905c0b32579..99b1e098d57ffcf028e54e7a14c36f7ba178fa45 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -31,6 +31,8 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 _uid_lock = threading.Lock() @@ -44,7 +46,7 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class Iterator(iterator_ops.EagerIterator): +class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the @@ -96,7 +98,6 @@ class Iterator(iterator_ops.EagerIterator): f=remote_fn, target_device=target, buffer_size=10, - thread_pool_size=1, container="", shared_name=_generate_shared_name("function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long @@ -106,13 +107,44 @@ class Iterator(iterator_ops.EagerIterator): def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - if self._buffer_resource_handle is not None: - with ops.device(self._device): - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - else: - return super(Iterator, self)._next_internal() + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + if self._buffer_resource_handle is not None: + with ops.device(self._device): + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + else: + return super(Iterator, self)._next_internal() + + # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset + # attributes(potential). + + class _Saveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator( + iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) + + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, + restored_tensors[0]) + + def _gather_saveables_for_checkpoint(self): + + def _saveable_factory(name): + return self._Saveable(self._resource, name) + + return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 4afadd88f59a79dde4f3af5175adbbbb18557ced..c658505de41bb6a0007440f4850fef720c3e97f1 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import threading import time @@ -24,6 +26,7 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.data.python.ops import threadpool from tensorflow.contrib.data.python.ops import unique +from tensorflow.contrib.eager.python import checkpointable_utils from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test @@ -221,6 +224,61 @@ class IteratorTest(test.TestCase): # perform work. self.assertLessEqual(len(thread_ids), num_threads) + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertAllEqual([1, 4], iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = datasets.Iterator(dataset) + iterator_2 = datasets.Iterator(dataset) + dataset_2 = Dataset.range(10) + iterator_3 = datasets.Iterator(dataset_2) + + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) + self.assertEqual(0, iterator_3.get_next().numpy()) + self.assertEqual(1, iterator_3.get_next().numpy()) + self.assertEqual(2, iterator_3.get_next().numpy()) + + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertAllEqual([9, 16], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator_1.get_next().numpy()) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(3) + iterator = datasets.Iterator(dataset) + + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertEqual(2, iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertEqual(2, iterator.get_next().numpy()) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index 2b7e199fad08c9a5e320b51b3a4de92c2d7dbb1a..b80c90902353709b7f739585291ec3b5890c27c7 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -32,6 +32,7 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data +layers = tf.keras.layers FLAGS = None @@ -56,15 +57,15 @@ class Discriminator(tf.keras.Model): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D( + self.conv1 = layers.Conv2D( 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) - self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) - self.conv2 = tf.layers.Conv2D( + self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = layers.Conv2D( 128, 5, data_format=data_format, activation=tf.tanh) - self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) - self.flatten = tf.layers.Flatten() - self.fc1 = tf.layers.Dense(1024, activation=tf.tanh) - self.fc2 = tf.layers.Dense(1, activation=None) + self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = layers.Flatten() + self.fc1 = layers.Dense(1024, activation=tf.tanh) + self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -112,16 +113,16 @@ class Generator(tf.keras.Model): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh) + self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) - self.conv1 = tf.layers.Conv2DTranspose( + self.conv1 = layers.Conv2DTranspose( 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = tf.layers.Conv2DTranspose( + self.conv2 = layers.Conv2DTranspose( 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ab847cb78a09ab0a38beefff56f87d8314c0713..4e1380afb2e6e722de65c691d4fbf44621072e87 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -32,6 +32,8 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe +layers = tf.keras.layers + class LinearModel(tf.keras.Model): """A TensorFlow linear regression model.""" @@ -39,7 +41,7 @@ class LinearModel(tf.keras.Model): def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = tf.layers.Dense(1) + self._hidden_layer = layers.Dense(1) def call(self, xs): """Invoke the linear model. diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 6b59413141f78fc85474850e109454ecdeb68cd3..a28bc8a43d7c90737c9baf9a634d736e9de52948 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -28,6 +28,8 @@ import functools import tensorflow as tf +layers = tf.keras.layers + class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. @@ -49,23 +51,23 @@ class _IdentityBlock(tf.keras.Model): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = tf.layers.Conv2D( + self.conv2a = layers.Conv2D( filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) - self.bn2a = tf.layers.BatchNormalization( + self.bn2a = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2a') - self.conv2b = tf.layers.Conv2D( + self.conv2b = layers.Conv2D( filters2, kernel_size, padding='same', data_format=data_format, name=conv_name_base + '2b') - self.bn2b = tf.layers.BatchNormalization( + self.bn2b = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2b') - self.conv2c = tf.layers.Conv2D( + self.conv2c = layers.Conv2D( filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) - self.bn2c = tf.layers.BatchNormalization( + self.bn2c = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): @@ -113,34 +115,34 @@ class _ConvBlock(tf.keras.Model): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = tf.layers.Conv2D( + self.conv2a = layers.Conv2D( filters1, (1, 1), strides=strides, name=conv_name_base + '2a', data_format=data_format) - self.bn2a = tf.layers.BatchNormalization( + self.bn2a = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2a') - self.conv2b = tf.layers.Conv2D( + self.conv2b = layers.Conv2D( filters2, kernel_size, padding='same', name=conv_name_base + '2b', data_format=data_format) - self.bn2b = tf.layers.BatchNormalization( + self.bn2b = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2b') - self.conv2c = tf.layers.Conv2D( + self.conv2c = layers.Conv2D( filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) - self.bn2c = tf.layers.BatchNormalization( + self.bn2c = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2c') - self.conv_shortcut = tf.layers.Conv2D( + self.conv_shortcut = layers.Conv2D( filters3, (1, 1), strides=strides, name=conv_name_base + '1', data_format=data_format) - self.bn_shortcut = tf.layers.BatchNormalization( + self.bn_shortcut = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): @@ -219,15 +221,15 @@ class ResNet50(tf.keras.Model): return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - self.conv1 = tf.layers.Conv2D( + self.conv1 = layers.Conv2D( 64, (7, 7), strides=(2, 2), data_format=data_format, padding='same', name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1') - self.max_pool = tf.layers.MaxPooling2D( + self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = layers.MaxPooling2D( (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) @@ -250,11 +252,12 @@ class ResNet50(tf.keras.Model): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = tf.layers.AveragePooling2D( + self.avg_pool = layers.AveragePooling2D( (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = tf.layers.Dense(classes, name='fc1000') + self.flatten = layers.Flatten() + self.fc1000 = layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -298,7 +301,7 @@ class ResNet50(tf.keras.Model): x = self.avg_pool(x) if self.include_top: - return self.fc1000(tf.layers.flatten(x)) + return self.fc1000(self.flatten(x)) elif self.global_pooling: return self.global_pooling(x) else: diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 65dcc53aab39670cae10846b6996c17d7b4c5ba8..d6923293a374f29ab77be70fa9fea44efd1ea40b 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -64,22 +64,29 @@ def train_one_step(model, images, labels, optimizer): class ResNet50Test(tf.test.TestCase): - def _apply(self, defun=False): + def _apply(self, defun=False, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): images, _ = random_batch(2) output = model(images, training=False) + tfe.async_wait() self.assertEqual((2, 1000), output.shape) def test_apply(self): self._apply(defun=False) + def test_apply_async(self): + self._apply(defun=False, execution_mode=tfe.ASYNC) + def test_apply_with_defun(self): self._apply(defun=True) + def test_apply_with_defun_async(self): + self._apply(defun=True, execution_mode=tfe.ASYNC) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -98,7 +105,7 @@ class ResNet50Test(tf.test.TestCase): output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - def test_train(self): + def _test_train(self, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() @@ -106,15 +113,22 @@ class ResNet50Test(tf.test.TestCase): with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) + tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') + def test_train(self): + self._test_train() + + def test_train_async(self): + self._test_train(execution_mode=tfe.ASYNC) + def test_no_garbage(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) @@ -183,59 +197,84 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False): - device, data_format = device_and_data_format() - model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) - batch_size = 64 - num_burn = 5 - num_iters = 30 - with tf.device(device): - images, _ = random_batch(batch_size) - for _ in xrange(num_burn): - model(images, training=False).cpu() - gc.collect() - start = time.time() - for _ in xrange(num_iters): - model(images, training=False).cpu() - self._report(label, start, num_iters, device, batch_size, data_format) - - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) - - def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - - def _benchmark_eager_train(self, label, make_iterator, defun=False): - device, data_format = device_and_data_format() - for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size) - num_burn = 3 - num_iters = 10 + def _benchmark_eager_apply(self, label, defun=False, execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(0.1) - + batch_size = 64 + num_burn = 5 + num_iters = 30 with tf.device(device): - iterator = make_iterator((images, labels)) + images, _ = random_batch(batch_size) for _ in xrange(num_burn): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() gc.collect() - start = time.time() for _ in xrange(num_iters): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_apply(self): + self._benchmark_eager_apply('eager_apply', defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + defun=False, + execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size) + num_burn = 3 + num_iters = 10 + model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) + optimizer = tf.train.GradientDescentOptimizer(0.1) + + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_train(self): self._benchmark_eager_train('eager_train', MockIterator, defun=False) + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + defun=False, + execution_mode=tfe.ASYNC) + def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( 'eager_train_with_defun', MockIterator, defun=True) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 29f02324544ede172500f799cd84068984d7d87b..492adbe1d80941f9df96d6636e4933d11239408e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -60,6 +60,7 @@ import functools import os import sys import time +import urllib import six import tensorflow as tf @@ -72,6 +73,8 @@ try: except ImportError: HAS_MATPLOTLIB = False +layers = tf.keras.layers + def parse(line): """Parse a line from the colors dataset.""" @@ -89,13 +92,35 @@ def parse(line): return rgb, chars, length +def maybe_download(filename, work_directory, source_url): + """Download the data from source url, unless it's already here. + + Args: + filename: string, name of the file in the directory. + work_directory: string, path to working directory. + source_url: url to download from if file doesn't exist. + + Returns: + Path to resulting file. + """ + if not tf.gfile.Exists(work_directory): + tf.gfile.MakeDirs(work_directory) + filepath = os.path.join(work_directory, filename) + if not tf.gfile.Exists(filepath): + temp_file_name, _ = urllib.request.urlretrieve(source_url) + tf.gfile.Copy(temp_file_name, filepath) + with tf.gfile.GFile(filepath) as f: + size = f.size() + print("Successfully downloaded", filename, size, "bytes.") + return filepath + + def load_dataset(data_dir, url, batch_size): """Loads the colors data at path into a PaddedDataset.""" # Downloads data at url into data_dir/basename(url). The dataset has a header # row (color_name, r, g, b) followed by comma-separated lines. - path = tf.contrib.learn.datasets.base.maybe_download( - os.path.basename(url), data_dir, url) + path = maybe_download(os.path.basename(url), data_dir, url) # This chain of commands loads our data by: # 1. skipping the header; (.skip(1)) @@ -129,7 +154,7 @@ class RNNColorbot(tf.keras.Model): self.cells = self._add_cells( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) - self.relu = tf.layers.Dense( + self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") def call(self, inputs, training=False): @@ -181,7 +206,7 @@ class RNNColorbot(tf.keras.Model): def _add_cells(self, cells): # "Magic" required for keras.Model classes to track all the variables in - # a list of tf.layers.Layer objects. + # a list of layers.Layer objects. # TODO(ashankar): Figure out API so user code doesn't have to do this. for i, c in enumerate(cells): setattr(self, "cell-%d" % i, c) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 69cd16d12c32c8c7c4744d8f0b4b1feedf946aa1..a90048d813bf345e8be32e9674a452175471b268 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -38,6 +38,8 @@ import tensorflow as tf from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.eager.python import tfe +layers = tf.keras.layers + class RNN(tf.keras.Model): """A static RNN. @@ -74,14 +76,14 @@ class RNN(tf.keras.Model): def _add_cells(self, cells): # "Magic" required for keras.Model classes to track all the variables in - # a list of tf.layers.Layer objects. + # a list of Layer objects. # TODO(ashankar): Figure out API so user code doesn't have to do this. for i, c in enumerate(cells): setattr(self, "cell-%d" % i, c) return cells -class Embedding(tf.layers.Layer): +class Embedding(layers.Layer): """An Embedding layer.""" def __init__(self, vocab_size, embedding_dim, **kwargs): @@ -132,7 +134,7 @@ class PTBModel(tf.keras.Model): else: self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) - self.linear = tf.layers.Dense( + self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) self._output_shape = [-1, embedding_dim] diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 98d01ad1d5a70788d2d4cb07031a8d76a6bf628f..5966f1d4873e8e77b3ad5914da7bfc7e69d4e341 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -39,6 +39,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", ], tags = [ + "no-internal-py3", # flaky "no_cuda_on_cpu_tap", "no_pip", # because spinn.py is under third_party/. ], diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 081b0af14fcc983a3f85d2a50e2bb04d2f2493b3..9adf47d505fc2933d9c009e5863351bd123c3797 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -33,6 +33,7 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -417,12 +418,17 @@ class SpinnTest(test_util.TensorFlowTestCase): if event.summary.value and event.summary.value[0].tag == "train/loss"] self.assertEqual(config.epochs, len(train_losses)) - self.assertLess(train_losses[-1], train_losses[0]) # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - ckpt_variable_names = [ - item[0] for item in checkpoint_utils.list_variables(config.logdir)] + object_graph_string = checkpoint_utils.load_variable( + config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") + object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() + object_graph.ParseFromString(object_graph_string) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) self.assertIn("global_step", ckpt_variable_names) for v in trainer.variables: variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index b73dc17e5f9cb15a51426f85e966a49604145f1d..2d2aba6908b168e0bf63f4706b6344cbb4ca82bd 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -1,892 +1,18 @@ -# TensorFlow Eager Execution - -## What is this? +# Eager execution Eager execution is a feature that makes TensorFlow execute operations -immediately: concrete values are returned, instead of a computational graph to -be executed later. - -As a result, enabling eager execution provides: - -- A [NumPy](http://www.numpy.org/)-like library for numerical computation with - support for GPU acceleration and automatic differentiation. -- A flexible platform for machine learning research and experimentation. - -Eager execution is under active development. This guide walks through an -alpha/preview release. In particular, not all TensorFlow APIs currently work -with eager execution enabled, and some models may be slow to execute, compared -to models defined without using eager execution. - -## Installation - -Eager execution is included in TensorFlow versions 1.5 and above. -Installation instructions at https://www.tensorflow.org/install/ - -The contents of this guide are compatible with TensorFlow 1.5. However, if you -run into bugs that are fixed in source but not the release, you may want to -either [build from source](https://www.tensorflow.org/install/install_sources) -or try a nightly build. The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -docker pull tensorflow/tensorflow:nightly-gpu -docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. - -## Getting Started - -With TensorFlow installed, eager execution is enabled via a single call: - -```python -import tensorflow as tf - -import tensorflow.contrib.eager as tfe - -tfe.enable_eager_execution() -``` - -Enabling eager execution changes how TensorFlow functions behave (in particular, -`Tensor` objects will reference concrete values instead of being symbolic -handles to nodes in a computational graph). As a result, eager execution should -be enabled at the beginning of a program and cannot be disabled afterwards in -the same program. - -Code examples in the rest of this guide assume that eager execution has been -enabled. - -## A library for numerical computation - -A significant fraction of the [TensorFlow -API](https://www.tensorflow.org/api_docs/python/) consists of numerical -operations: -[arithmetic operations](https://www.tensorflow.org/api_guides/python/math_ops#Arithmetic_Operators), -[matrix operations](https://www.tensorflow.org/api_guides/python/math_ops#Matrix_Math_Functions), -[linear algebra operations](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), -etc. - -With eager execution enabled, these operations consume and return -multi-dimensional arrays as `Tensor` objects, similar to NumPy -[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html). -For example: - -```python -# Multiply two 2x2 matrices -x = tf.matmul([[1, 2], - [3, 4]], - [[4, 5], - [6, 7]]) -# Add one to each element -# (tf.add supports broadcasting) -y = tf.add(x, 1) - -# Create a random random 5x3 matrix -z = tf.random_uniform([5, 3]) - -print(x) -print(y) -print(z) -``` - -Output: - -``` -tf.Tensor( -[[16 19] - [36 43]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[17 20] - [37 44]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[ 0.25058532 0.0929395 0.54113817] - [ 0.3108716 0.93350542 0.84909797] - [ 0.53081679 0.12788558 0.01767385] - [ 0.29725885 0.33540785 0.83588314] - [ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32) -``` - -For convenience, these operations can also be triggered via operator overloading -of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`, -`-` to `tf.subtract`, `*` to `tf.multiply`, etc.: - -```python -x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1 -print(x) -``` - -Output: - -``` -tf.Tensor([ 3.], shape=(1,), dtype=float32) -``` - -### Converting to and from NumPy - -The operations above automatically convert Python objects (like lists of -numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used -as NumPy arrays by numpy operations. - -```python -import numpy as np - -x = tf.add(1, 1) # tf.Tensor with a value of 2 -y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2 -z = np.multiply(x, y) # numpy.int64 with a value of 4 -``` - -Alternatively, they can be explicitly converted using -[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as -shown in the next example. - -Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain -its NumPy `ndarray` value. For example: - -```python -import numpy as np - -np_x = np.array(2., dtype=np.float32) -x = tf.constant(np_x) - -py_y = 3. -y = tf.constant(py_y) - -z = x + y + 1 - -print(z) -print(z.numpy()) -``` - -Output: - -``` -tf.Tensor(6.0, shape=(), dtype=float32) -6.0 -``` - -### GPU acceleration - -Many TensorFlow operations support GPU acceleration. With eager execution -enabled, [computation is *not* automatically -offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you -must explicitly specify when GPUs should be used. - -The simplest way to do this is to enclose your computation in a `with -tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function, -which returns the number of available GPUs. - -For example, consider this snippet to measure the time to multiply two 1000x1000 -matrices on CPU: - -```python -import time - -def measure(x): - # The very first time a GPU is used by TensorFlow, it is initialized. - # So exclude the first run from timing. - tf.matmul(x, x) - - start = time.time() - for i in range(10): - tf.matmul(x, x) - end = time.time() - - return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape) - -# Run on CPU: -with tf.device("/cpu:0"): - print("CPU: %s" % measure(tf.random_normal([1000, 1000]))) - -# If a GPU is available, run on GPU: -if tfe.num_gpus() > 0: - with tf.device("/gpu:0"): - print("GPU: %s" % measure(tf.random_normal([1000, 1000]))) -``` - -Output (exact numbers will depend on the characteristics of the hardware): - -```python -CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times -GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times -``` - -Alternatively, methods on the `Tensor` object can be used to explicitly copy the -`Tensor` to a different device. Operations are typically executed on the device -on which the inputs are placed. For example: - -```python -x = tf.random_normal([10, 10]) - -x_gpu0 = x.gpu() -x_cpu = x.cpu() - -_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU -_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0 - -if tfe.num_gpus() > 1: - x_gpu1 = x.gpu(1) - _ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1 -``` - -### Automatic Differentiation - -[Automatic -differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is -very useful when implementing many machine learning algorithms (e.g., -[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training -neural networks). For this purpose, TensorFlow eager execution provides an -[autograd](https://github.com/HIPS/autograd)-style API for automatic -differentiation. Specifically, the functions: - -- `tfe.gradients_function(f)`: Returns a Python function that computes the - derivatives of the Python function `f` with respect to its arguments. `f` - must return a scalar value. When the returned function is invoked, it - returns a list of `Tensor` objects (one element for each argument of `f`). -- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`, - except that when the returned function is invoked, it returns the value of - `f` in addition to the list of derivatives of `f` with respect to its - arguments. - -These functions naturally apply to higher order differentiation as well. For -example: - -```python -def f(x): - return tf.multiply(x, x) # Or x * x -assert 9 == f(3.).numpy() - -df = tfe.gradients_function(f) -assert 6 == df(3.)[0].numpy() - -# Second order deriviative. -d2f = tfe.gradients_function(lambda x: df(x)[0]) -assert 2 == d2f(3.)[0].numpy() - -# Third order derivative. -d3f = tfe.gradients_function(lambda x : d2f(x)[0]) -assert 0 == d3f(3.)[0].numpy() -``` - -These functions can be used to train models. For example, consider the following -simple linear regression model: - -```python -def prediction(input, weight, bias): - return input * weight + bias - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# A loss function: Mean-squared error -def loss(weight, bias): - error = prediction(training_inputs, weight, bias) - training_outputs - return tf.reduce_mean(tf.square(error)) - -# Function that returns the derivative of loss with respect to -# weight and bias -grad = tfe.gradients_function(loss) - -# Train for 200 steps (starting from some random choice for W and B, on the same -# batch of data). -W = 5. -B = 10. -learning_rate = 0.01 -print("Initial loss: %f" % loss(W, B).numpy()) -for i in range(200): - (dW, dB) = grad(W, B) - W -= dW * learning_rate - B -= dB * learning_rate - if i % 20 == 0: - print("Loss at step %d: %f" % (i, loss(W, B).numpy())) -print("Final loss: %f" % loss(W, B).numpy()) -print("W, B = %f, %f" % (W.numpy(), B.numpy())) -``` - -Output: (the exact numbers may vary depending on the randomness in noise) - -``` -Initial loss: 66.730003 -Loss at step 0: 64.200096 -Loss at step 20: 29.872814 -Loss at step 40: 14.233772 -Loss at step 60: 7.090570 -Loss at step 80: 3.819887 -Loss at step 100: 2.318821 -Loss at step 120: 1.628385 -Loss at step 140: 1.310142 -Loss at step 160: 1.163167 -Loss at step 180: 1.095162 -Final loss: 1.064711 -W, B = 3.094944, 2.161383 -``` - -To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):` -block. (However, this particular model, with only two floating point parameters, -is unlikely to benefit from GPU acceleration.) - -### Customizing gradients - -One may want to define custom gradients for an operation, or for a function. -This may be useful for multiple reasons, including providing a more efficient -or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability) -gradient for a sequence of operations. - -For example, consider the function `log(1 + e^x)`, which commonly occurs in the -computation of cross entropy and log likelihoods. - -```python -def log1pexp(x): -  return tf.log(1 + tf.exp(x)) -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works fine at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# Returns a `nan` at x = 100 due to numerical instability. -import math -assert math.isnan(float(grad_log1pexp(100.)[0])) -``` - -We can define a custom gradient for the above function that analytically -simplifies the gradient expression. - -```python -@tfe.custom_gradient -def log1pexp(x): -  e = tf.exp(x) -  def grad(dy): -    return dy * (1 - 1 / (1 + e)) -  return tf.log(1 + e), grad -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works as before at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# But now works at x = 100 as well. -assert 1.0 == float(grad_log1pexp(100.)[0]) -``` -Also notice how the gradient function implementation reuses an expression -(`tf.exp(x)`) computed during the forward pass, hence making the gradient -computation more efficient by avoiding redundant computation. - -## Building and training models - -In practice, your computation may have many parameters to be optimized (by -computing derivatives). Encapsulating them into re-usable classes/objects -makes the code easier to follow than writing a single top-level function with -many arguments. - -In fact, eager execution encourages use of the [Keras](https://keras.io)-style -"Layer" classes in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. - -Furthermore, you may want to apply more sophisticated techniques to compute -parameter updates, such as those in -[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers) -implementations. - -This next section walks through using the same `Optimizer` and `Layer` APIs used -to build trainable TensorFlow graphs in an environment where eager execution is -enabled. - -### Variables and Optimizers - -`tfe.Variable` objects store mutable `Tensor` values that can be accessed during -training, making automatic differentiation easier. In particular, parameters of -a model can be encapsulated in Python classes as variables. - -`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f` -with respect to its arguments. However, it requires all parameters of interest -to be arguments of `f`, which becomes cumbersome when `f` depends on a large -number of trainable parameters. - -`tfe.implicit_gradients` is an alternative function with some useful properties: - -- It computes the derivatives of `f` with respect to all the `tfe.Variable`s - used by `f`. -- When the returned function is invoked, it returns a list of - (gradient value, Variable object) tuples. - -Representing model parameters as `Variable` objects, along with the use of -`tfe.implicit_gradients`, typically results in better encapsulation. For -example, the linear regression model described above can be written into a -class: - -```python -class Model(object): - def __init__(self): - self.W = tfe.Variable(5., name='weight') - self.B = tfe.Variable(10., name='bias') - - def predict(self, inputs): - return inputs * self.W + self.B - - -# The loss function to be optimized -def loss(model, inputs, targets): - error = model.predict(inputs) - targets - return tf.reduce_mean(tf.square(error)) - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# Define: -# 1. A model -# 2. Derivatives of a loss function with respect to model parameters -# 3. A strategy for updating the variables based on the derivatives -model = Model() -grad = tfe.implicit_gradients(loss) -optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - -# The training loop -print("Initial loss: %f" % - loss(model, training_inputs, training_outputs).numpy()) -for i in range(201): - optimizer.apply_gradients(grad(model, training_inputs, training_outputs)) - if i % 20 == 0: - print("Loss at step %d: %f" % - (i, loss(model, training_inputs, training_outputs).numpy())) -print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy()) -print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy())) -``` - -Output: - -``` -Initial loss: 69.693184 -Loss at step 0: 66.987854 -Loss at step 20: 30.553387 -Loss at step 40: 14.250237 -Loss at step 60: 6.955020 -Loss at step 80: 3.690550 -Loss at step 100: 2.229739 -Loss at step 120: 1.576032 -Loss at step 140: 1.283496 -Loss at step 160: 1.152584 -Loss at step 180: 1.093999 -Final loss: 1.067780 -W, B = 3.0114281, 2.0865183 -``` - -Using `implicit_gradients` avoids the need to provide all the trainable -parameters of the model as arguments to the `loss` function. - -### Using Keras and the Layers API - -[Keras](https://keras.io) is a popular API for defining model structures. The -[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) -module provides a set of building blocks for models and is implemented using the -`tf.layers.Layer` subclasses in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. We encourage the use of these same building blocks when using -TensorFlow's eager execution feature. For example, the very same linear -regression model can be built using `tf.layers.Dense`: - -```python -class Model(object): - def __init__(self): - self.layer = tf.layers.Dense(1) - - def predict(self, inputs): - return self.layer(inputs) -``` - -The `tf.layers` API makes it more convenient to define more sophisticated -models. For example, the following will train an MNIST model: - -```python -class MNISTModel(object): - def __init__(self, data_format): - # 'channels_first' is typically faster on GPUs - # while 'channels_last' is typically faster on CPUs. - # See: https://www.tensorflow.org/performance/performance_guide#data_formats - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D(32, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.max_pool2d = tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='same', data_format=data_format) - self.conv2 = tf.layers.Conv2D(64, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu) - self.dropout = tf.layers.Dropout(0.5) - self.dense2 = tf.layers.Dense(10) - - def predict(self, inputs): - x = tf.reshape(inputs, self._input_shape) - x = self.max_pool2d(self.conv1(x)) - x = self.max_pool2d(self.conv2(x)) - x = tf.layers.flatten(x) - x = self.dropout(self.dense1(x)) - return self.dense2(x) - -def loss(model, inputs, targets): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=model.predict(inputs), labels=targets)) - - -# Load the training and validation data -from tensorflow.examples.tutorials.mnist import input_data -data = input_data.read_data_sets("./mnist_data", one_hot=True) - -# Train -device = "gpu:0" if tfe.num_gpus() else "cpu:0" -model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last') -optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) -grad = tfe.implicit_gradients(loss) -for i in range(20001): - with tf.device(device): - (inputs, targets) = data.train.next_batch(50) - optimizer.apply_gradients(grad(model, inputs, targets)) - if i % 100 == 0: - print("Step %d: Loss on training set : %f" % - (i, loss(model, inputs, targets).numpy())) -print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy()) -``` - -For a more complete example, see [the example in the tensorflow/models -repository](https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py). - -### Checkpointing trained variables - -TensorFlow Variables (`tfe.Variable`) provide a way to represent shared, -persistent state of your model. The `tfe.Checkpoint` class provides a means to -save and restore variables to and from _checkpoints_. - -For example: - -```python -# Create variables. -x = tfe.Variable(10.) -y = tfe.Variable(5.) - -# Indicate that the variables should be saved as "x" and "y". -checkpoint = tfe.Checkpoint(x=x, y=y) - -# Assign new values to the variables and save. -x.assign(2.) -checkpoint.save('/tmp/ckpt') - -# Change the variable after saving. -x.assign(11.) -assert 16. == (x + y).numpy() # 11 + 5 - -# Restore the values in the checkpoint. -checkpoint.restore('/tmp/ckpt-1') - -assert 7. == (x + y).numpy() # 2 + 5 -``` - -### `tf.keras.Model` - -You may often want to organize your models using classes, like the `MNISTModel` -class described above. We recommend inheriting from the `tf.keras.Model` class -as it provides conveniences like keeping track of all model variables. - -Sub-classes of `tf.keras.Model` may register `Layer`s (like classes in -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), or [Keras -layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) by -assigning them to attributes (`self.name = layer_object`) and define the -computation in an implementation of `call()`. - -Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables -lazily, when the first input is encountered. - -For example, consider the following two-layer neural network: - -```python -class TwoLayerNet(tf.keras.Model): - def __init__(self): - super(TwoLayerNet, self).__init__() - self.layer1 = tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False) - self.layer2 = tf.layers.Dense(3, use_bias=False) - - def call(self, x): - return self.layer2(self.layer1(x)) - -net = TwoLayerNet() - -# No variables created yet -assert 0 == len(net.variables) - -# They are created on first input: -inp = tf.constant([[1.]]) - -# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units, -# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3 -# matrix. -assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape. -assert 1 == len(net.layer1.variables) -assert 1 == len(net.layer2.variables) -assert 2 == len(net.variables) # weights for each layer. -assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1. -assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2. -``` - -The `tf.keras.Model` class is itself a sub-class of `tf.layers.Layer`. This -allows instances of `tf.keras.Model` to be embedded in other models. For -example: - -```python -class ThreeLayerNet(tf.keras.Model): - def __init__(self): - super(ThreeLayerNet, self).__init__() - self.a = TwoLayerNet() - self.b = tf.layers.Dense(4, use_bias=False) - - def call(self, x): - return self.b(self.a(x)) - -net = ThreeLayerNet() - -assert [1, 4] == net(inp).shape.as_list() -assert 3 == len(net.variables) -assert [1, 2] == net.variables[0].shape.as_list() -assert [2, 3] == net.variables[1].shape.as_list() -assert [3, 4] == net.variables[2].shape.as_list() -``` - -See more examples in -[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples). - -`tfe.Checkpoint` provides a convenient way to save and load training -checkpoints. Let's define something simple to train. We set an objective for the -output of our network, choose an optimizer, and a location for the checkpoint: - -```python -objective = tf.constant([[2., 3., 4., 5.]]) -optimizer = tf.train.AdamOptimizer(0.01) -checkpoint_directory = '/tmp/tfe_example' -checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') -net = ThreeLayerNet() -``` - -We group them in a `tfe.Checkpoint` and request that it be restored. This -ensures that variables created by these objects are restored before their values -are used. Our training loop is the same whether starting training or resuming -from a previous checkpoint: - -```python -global_step = tf.train.get_or_create_global_step() -checkpoint = tfe.Checkpoint( - global_step=global_step, optimizer=optimizer, network=net) -checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) -for _ in range(100): - loss_fn = lambda: tf.norm(net(inp) - objective) - optimizer.minimize(loss_fn, global_step=global_step) - if tf.equal(global_step % 20, 0): - print("Step %d, output %s" % (global_step.numpy(), - net(inp).numpy())) - # Save the checkpoint. - checkpoint.save(checkpoint_prefix) -``` - -The first time it runs, `Model` variables are initialized randomly. Then the -output is trained to match the objective we've set: - -``` -Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]] -Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]] -Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]] -Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]] -Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]] -``` - -In subsequent iterations, variables are initialized with the values read from -the latest checkpoint. Running the same code again, we continue from where we -left off: - -``` -Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]] -Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]] -Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]] -Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]] -Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]] -``` - - -### Summaries, metrics and TensorBoard - -[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) -is a popular tool for understanding, debugging and optimizing the model training -process. To benefit from the visualizations offered by TensorBoard, summary -events need to be written during the course of execution of your program. You -might find many Tensorflow programs that include the -[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations -during graph construction. - -`tf.summary` operations are *not* compatible with eager execution, but an -equivalent alternative exists in -[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/summary) -that is compatible with both eager execution and graph construction. - -During model construction simply insert summary operations like -`tf.contrib.summary.scalar`. These operations do nothing by default, unless a -summary writer is currently active and a writing policy is set. - -For example, to record summaries once every 100 global steps, use: - -```python -tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_file_writer(logdir) - -for _ in range(iterations): - with writer.as_default(): - with tf.contrib.summary.record_summaries_every_n_global_steps(100): - # your model code goes here - tf.contrib.summary.scalar('loss', loss) - # ... -``` - -See the full mnist example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -for a full model using `tf.contrib.summary`. - -Similarly to summaries, the metrics in `tf.metrics` are currently not compatible -with eager execution. We instead provide object-oriented metrics in the -`tfe.metrics` package, which are compatible with graph construction as well. - -Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and -`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented -interface. Here's an example of how to use the `tfe.metrics.Mean` metric: - -```python -# Metrics are objects, which can be created and destroyed. -my_mean = tfe.metrics.Mean(name='my_mean') -# While a metric is active, you can call it as a function to accumulate into its -# internal state. -my_mean(0.0) -my_mean(10.0) -# Once you've finished updating the metric, you can get its result. In this case -# a simple average over all the calls to it. If a summary writer is active the -# metric will write the appropriate summaries using the metric name. -assert 5.0 == my_mean.result().numpy() -``` - -For a full example of a model using metrics for evaluation, see the mnist -example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist). - -### Input Pipelines - -The discussion above has been centered around the computation executed by your -model. The -[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) -module provides APIs to build complex input pipelines from simple, reusable -pieces. - -If you're familiar with constructing `tf.data.Dataset` objects when building -TensorFlow graphs, the same API calls are used when eager execution is enabled. -However, the process of iterating over elements of the dataset differs between -eager execution and graph construction. When eager execution is enabled, the -discussion on iterator creation using `make_one_shot_iterator()` and -`get_next()` in the -[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is -*not* applicable. Instead, a more Pythonic `Iterator` class is available. - -For example: - -```python -# Create a source Dataset from in-memory numpy arrays. -# For reading from files on disk, you may want to use other Dataset classes -# like the TextLineDataset or the TFRecordDataset. -dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]) - -# Apply transformations, shuffling, batching etc. -dataset = dataset.map(tf.square).shuffle(2).batch(2) - -# Use tfe.Iterator to iterate over the dataset. -for x in tfe.Iterator(dataset): - print(x) -``` - -Output: - -``` -tf.Tensor([4 9], shape=(2,), dtype=int32) -tf.Tensor([16 25], shape=(2,), dtype=int32) -tf.Tensor([36 1], shape=(2,), dtype=int32) -``` - -## Interoperating with Graphs - -Eager execution improves the process of model development in Python; however, -because it is in its earliest stages, it does not yet support some features -available to [TensorFlow -graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph) -that are desirable when deploying models in production. In particular, eager -execution does not yet support distributed training, exporting models (to other -[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow -serving](https://www.tensorflow.org/serving/), and mobile applications), and -various memory and computation optimizations that are applied to TensorFlow's -dataflow graphs. - -That said, the APIs used to build modes are exactly the same whether executing -eagerly or constructing graphs. This means that you can iteratively develop your -model with eager execution enabled and later, if needed, use the same code to -reap the benefits of representing models as computational graphs. - -For example, the same model definition used to construct a graph in -[mnist.py`](https://github.com/tensorflow/models/tree/master/official/mnist/mnist.py) -can be trained with eager execution enabled as in [`mnist_eager.py`](https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py). - -Other models in the [examples -directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/) -demonstrate this as well. - -Some differences worth noting: - -- There is no notion of a `tf.placeholder` or a `tf.Session` when eager - execution is enabled. -- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`, - `tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution - is enabled and their use will raise an `AttributeError`. -- To use `tfe.implicit_gradients` in graph construction, variables must be - created with [`use_resource=True`] provided to - [`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable) - or - [`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope). -- Some API calls (such as the functional-style `tf.layers.dense`, - `tf.layers.conv2d`) are not compatible with eager execution. Use of such - methods should raise an error indicating the alternative (e.g., the - `tf.layers.Dense` and `tf.layers.Conv2D` classes). - -## What next? +immediately: concrete values are returned, instead of creating a computational +graph that is executed later. -Please give eager execution a spin. This feature is in early stages and is -evolving, so we welcome your feedback via issues on GitHub (see [known -issues](https://github.com/tensorflow/tensorflow/labels/comp:eager)). +A user guide is available: https://www.tensorflow.org/programmers_guide/eager +([source file](../../../../docs_src/programmers_guide/eager.md)) -You may want to browse through some sample code, including benchmarks for some: +We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). -- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) -- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) -- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) -- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) +Sample code is available, including benchmarks for some: +- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) +- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) +- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) +- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) +- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 1490c2ccacd55156bcc1cf8c07d9941336e18e1b..2f2347736a073c7d9b3fb6685f52f8d58cc40570 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -109,6 +109,18 @@ class Metric(checkpointable.CheckpointableBase): pos = scope.name.rfind(scope_name) self._name = name + scope.name[pos + len(scope_name):] self._scope = scope + + # Ensures that if the user calls build directly we still set self._built to + # True to prevent variables from being recreated. + self._build = self.build + + def actual_build(*args, **kwargs): + self._build(*args, **kwargs) + self._built = True + self.build = actual_build + self.build.__doc__ = self._build.__doc__ + + # Captures construction scope for proper initialization. if context.executing_eagerly(): self._construction_scope = context.eager_mode else: diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 6b5450ba89bdfa6e0195f488b75f596b58c463d5..15ac889191e0fe51269bc5740d5e0ab1bc0e2b72 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -195,6 +195,15 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testBuildMean(self): + # Verify that calling build() on Mean and then calling it won't recreate + # variables. + m = metrics.Mean() + m.build() + old_numer = m.numer + m(0.0) + self.assertTrue(old_numer is m.numer) + def testMetricsChain(self): with context.graph_mode(), self.test_session(): m1 = metrics.Mean() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 4c937716e8df7c8cda26d6431885ce33346b77fb..e55a9276ab53f44f76dc5e537b3bdde7c975f463 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -149,7 +149,7 @@ class Network(base.Layer): # check we might have name collisions if the parent scope on init gets # closed before build is called. self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) + variable_scope.get_variable_scope_store().variable_scopes_count) def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 5aabc9aae868021284e83a4c4d80d65c2ee63fca..c6f3f20e781147140f2c4b339ed465ab7e919d37 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -62,12 +62,18 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@executing_eagerly @@in_eager_mode +@@set_execution_mode +@@execution_mode +@@async_wait +@@async_clear_error @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT +@@SYNC +@@ASYNC """ from __future__ import absolute_import @@ -95,6 +101,12 @@ from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT from tensorflow.python.eager.context import executing_eagerly from tensorflow.python.eager.context import list_devices +from tensorflow.python.eager.context import set_execution_mode +from tensorflow.python.eager.context import execution_mode +from tensorflow.python.eager.context import async_wait +from tensorflow.python.eager.context import async_clear_error +from tensorflow.python.eager.context import SYNC +from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 773c6ab6c79217698c7c598a133082e2553f28f6..bec0329ebbd82b06fba6a8283500ad7f3a11b6a2 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -9,23 +9,12 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "estimator_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":boosted_trees", ":dnn", ":dnn_linear_combined", ":extenders", @@ -38,6 +27,36 @@ py_library( ], ) +py_library( + name = "boosted_trees", + srcs = ["python/estimator/boosted_trees.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:boosted_trees", + ], +) + +py_test( + name = "boosted_trees_test", + size = "medium", + srcs = ["python/estimator/boosted_trees_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":boosted_trees", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], +) + py_library( name = "dnn", srcs = ["python/estimator/dnn.py"], @@ -70,6 +89,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -110,6 +130,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -142,6 +163,7 @@ py_test( deps = [ ":extenders", "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/predictor", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", @@ -174,6 +196,7 @@ py_library( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", @@ -245,6 +268,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -291,6 +315,8 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", @@ -354,6 +380,7 @@ cuda_py_test( size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 6b9f9575b606f1822d760e8597c55994dd8af04c..d2fc2c4bfa448227819c8d706387c1c75062b80b 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.extenders import * @@ -44,6 +45,8 @@ _allowed_symbols = [ 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', + 'boosted_trees_classifier_train_in_memory', + 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..314c54ed00372eca62ffc6930e6d492dd7d57163 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Boosted Trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees + + +class _BoostedTreesEstimator(estimator.Estimator): + """An Estimator for Tensorflow Boosted Trees models.""" + + def __init__(self, + feature_columns, + n_batches_per_layer, + head, + model_dir=None, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + config=None): + """Initializes a `BoostedTreesEstimator` instance. + + Args: + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + n_batches_per_layer: the number of batches to collect statistics per + layer. + head: the `Head` instance defined for Estimator. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + config: `RunConfig` object to configure the runtime settings. + """ + # pylint:disable=protected-access + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, labels, mode, head, feature_columns, tree_hparams, + n_batches_per_layer, config) + + super(_BoostedTreesEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + # pylint:enable=protected-access + + +def boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT, + weight_column=None, + label_vocabulary=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + config=None, + train_hooks=None): + """Trains a boosted tree classifier with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def input_fn_train(): + dataset = create-dataset-from-training-data + # Don't use repeat or cache, since it is assumed to be one epoch + # This is either tf.data.Dataset, or a tuple of feature dict and label. + return dataset + + classifier = boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + n_classes: number of label classes. Default is binary classification. + Multiclass support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Support multi-class cases. + if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT: + n_classes = 2 + head, closed_form = ( + canned_boosted_trees._create_classification_head_and_closed_form( + n_classes, weight_column, label_vocabulary=label_vocabulary)) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + closed_form_grad_and_hess_fn=closed_form, + train_in_memory=True) + + in_memory_classifier = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks) + + return in_memory_classifier + # pylint: enable=protected-access + + +def boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + config=None, + train_hooks=None): + """Trains a boosted tree regressor with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def input_fn_train(): + dataset = create-dataset-from-training-data + # Don't use repeat or cache, since it is assumed to be one epoch + # This is either tf.data.Dataset, or a tuple of feature dict and label. + return dataset + + regressor = boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + label_dimension: Number of regression targets per example. + Multi-dimensional support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Extend it to multi-dimension cases. + if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT: + label_dimension = 1 + head = canned_boosted_trees._create_regression_head(label_dimension, + weight_column) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + train_in_memory=True) + + in_memory_regressor = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks) + + return in_memory_regressor + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a87f3b3c0e7c5840fa250506e600645bf6a29 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -0,0 +1,207 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests boosted_trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.estimator.python.estimator import boosted_trees +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils + +NUM_FEATURES = 3 + +BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features. +INPUT_FEATURES = np.array( + [ + [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1] + [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1] + [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3] + ], + dtype=np.float32) +CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]] +REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]] +FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)} + + +def _make_train_input_fn(is_classification): + """Makes train input_fn for classification/regression.""" + + def _input_fn(): + features = dict(FEATURES_DICT) + if is_classification: + labels = CLASSIFICATION_LABELS + else: + labels = REGRESSION_LABELS + return features, labels + + return _input_fn + + +class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._head = canned_boosted_trees._create_regression_head(label_dimension=1) + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) + for i in range(NUM_FEATURES) + } + + def _assert_checkpoint(self, model_dir, expected_global_step): + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + def testTrainAndEvaluateEstimator(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5) + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint(est.model_dir, 11) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 0.913176) + + def testInferEstimator(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + head=self._head) + + # It will stop after 5 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint(est.model_dir, 6) + + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertEquals(5, len(predictions)) + self.assertAllClose([0.703549], predictions[0]['predictions']) + self.assertAllClose([0.266539], predictions[1]['predictions']) + self.assertAllClose([0.256479], predictions[2]['predictions']) + self.assertAllClose([1.088732], predictions[3]['predictions']) + self.assertAllClose([1.901732], predictions[4]['predictions']) + + +class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) + for i in range(NUM_FEATURES) + } + + def _assert_checkpoint(self, model_dir, expected_global_step): + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint(est.model_dir, 6) + + # Check eval. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + + # Check predict that all labels are correct. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertEquals(5, len(predictions)) + self.assertAllClose([0], predictions[0]['class_ids']) + self.assertAllClose([1], predictions[1]['class_ids']) + self.assertAllClose([1], predictions[2]['class_ids']) + self.assertAllClose([0], predictions[3]['class_ids']) + self.assertAllClose([0], predictions[4]['class_ids']) + + +class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) + for i in range(NUM_FEATURES) + } + + def _assert_checkpoint(self, model_dir, expected_global_step): + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + def testRegressorTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint(est.model_dir, 6) + + # Check eval. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.2136638) + + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertEquals(5, len(predictions)) + self.assertAllClose([0.703549], predictions[0]['predictions']) + self.assertAllClose([0.266539], predictions[1]['predictions']) + self.assertAllClose([0.256479], predictions[2]['predictions']) + self.assertAllClose([1.088732], predictions[3]['predictions']) + self.assertAllClose([1.901732], predictions[4]['predictions']) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4..dd009a6753f3231638f93e50fc8f19eae8820139 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -34,6 +34,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -52,7 +53,9 @@ def _dnn_only_estimator_fn( config=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, dnn_feature_columns=feature_columns, dnn_optimizer=optimizer, @@ -100,7 +103,9 @@ def _linear_only_estimator_fn( partitioner=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, linear_feature_columns=feature_columns, linear_optimizer=optimizer, diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 71f810acec856d42d389260e7b9fea32123348b4..75e3107670d658e55ce23d983e47311f1c180104 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -41,7 +42,9 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 2b6881b81487dfdb682d5d6261a0318c59d461f6..266ae933052b11b9ab3edb662e95c90aae207dae 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops @@ -233,7 +234,17 @@ def forward_features(estimator, keys=None): 'argument of forward_features to filter unwanted features. Type of ' 'features[{}] is {}.'.format(key, key, type(feature))) predictions[key] = feature - return spec._replace(predictions=predictions) + spec = spec._replace(predictions=predictions) + if spec.export_outputs: + for ekey in ['predict', 'serving_default']: + if (ekey in spec.export_outputs and + isinstance(spec.export_outputs[ekey], + PredictOutput)): + export_outputs = spec.export_outputs[ekey].outputs + for key in get_keys(features): + export_outputs[key] = predictions[key] + + return spec return estimator_lib.Estimator( model_fn=new_model_fn, diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index ad1a8ef152b07ecbab33d9eb3184a2ae89def27d..407af2deaf0928361a4f0b0e44e842b7750118cb 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -18,20 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile import numpy as np from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.predictor import from_saved_model from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import training +from tensorflow.python.util import compat def get_input_fn(x, y): @@ -177,6 +184,44 @@ class ForwardFeaturesTest(test.TestCase): self.assertIn('id', predictions) self.assertEqual(101, predictions['id']) + def test_forward_in_exported(self): + + def serving_input_fn(): + features_ph = { + 'x': array_ops.placeholder(dtypes.float32, [None]), + 'id': array_ops.placeholder(dtypes.int32, [None]) + } + features = { + key: array_ops.expand_dims(tensor, -1) + for key, tensor in features_ph.items() + } + return estimator_lib.export.ServingInputReceiver(features, features_ph) + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + # create estimator + feature_columns = [fc.numeric_column('x')] + estimator = linear.LinearRegressor(feature_columns) + estimator.train(input_fn=input_fn, steps=1) + estimator = extenders.forward_features(estimator, 'id') + + # export saved model + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) + self.assertTrue(gfile.Exists(export_dir)) + + # restore model + predict_fn = from_saved_model(export_dir, signature_def_key='predict') + predictions = predict_fn({'x': [3], 'id': [101]}) + + # verify that 'id' exists in predictions + self.assertIn('id', predictions) + self.assertEqual(101, predictions['id']) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_forward_list(self): def input_fn(): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index f95fcc8039cb54c26543781b31013a7676168b0b..85ef3291bae44d3c3126d778eba718ebe15993b5 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -36,10 +36,12 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY +# TODO(b/65403806): Switch loss_reduction default to SUM_OVER_BATCH_SIZE. def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, @@ -176,7 +178,7 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, inverse_link_fn=None, name=None): @@ -216,7 +218,9 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. loss_fn: Optional loss function. Defaults to `mean_squared_error`. inverse_link_fn: Optional inverse link function, also known as 'mean function'. Defaults to identity. @@ -241,7 +245,7 @@ def regression_head(weight_column=None, def poisson_regression_head( weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, compute_full_loss=True, name=None): """Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`. @@ -273,7 +277,9 @@ def poisson_regression_head( of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. compute_full_loss: Whether to include the constant `log(z!)` term in computing the poisson loss. See `tf.nn.log_poisson_loss` for the full documentation. @@ -302,7 +308,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -353,7 +359,8 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -402,7 +409,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): self._n_classes = n_classes @@ -489,8 +496,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -502,8 +509,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -513,7 +523,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -565,8 +576,16 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -592,7 +611,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _eval_metric_ops( self, labels, probabilities, weights, unreduced_loss, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index dc30dde877ab5f912e3f6a724d481b151a3ed044..98962ca4277a3e8fbbdb3fb2d26df9acc45168b5 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -272,9 +272,9 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - expected_training_loss = np.sum( + # loss = (labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / 2 + expected_training_loss = 0.5 * np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -298,7 +298,7 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_training_loss = np.sum( + expected_training_loss = 0.5 * np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -361,7 +361,7 @@ class MultiLabelHead(test.TestCase): labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -438,12 +438,13 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -468,14 +469,13 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -533,14 +533,13 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -562,15 +561,14 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -603,8 +601,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2 + expected_loss = 12.5 spec = head.create_estimator_spec( features={ @@ -617,8 +616,8 @@ class MultiLabelHead(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { - # Average loss over weighted examples. - keys.LOSS_MEAN: expected_loss / 3, + # Average loss over weighted examples (denominator is sum(weights)). + keys.LOSS_MEAN: expected_loss * (2. / 3.), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, @@ -663,7 +662,7 @@ class MultiLabelHead(test.TestCase): # (1 - labels) * (logits > 0) * logits expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] expected_weights = [[1.], [2.]] - expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), @@ -809,11 +808,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol) def test_train(self): head = head_lib.multi_label_head(n_classes=2) @@ -823,8 +819,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -840,8 +837,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -858,11 +856,49 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_optimizer(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) @@ -916,8 +952,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2 + expected_loss = 12.5 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -951,11 +988,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over weighted examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol) def test_multi_dim_weighted_train_create_loss(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" @@ -972,8 +1006,8 @@ class MultiLabelHead(test.TestCase): expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] # weights are reshaped to [2, 2, 1] to match logits. expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_training_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_training_loss = 9.9167 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, @@ -999,8 +1033,8 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -1088,11 +1122,11 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 keys = metric_keys.MetricKeys expected_metrics = { - keys.LOSS_MEAN: expected_loss / np.sum(weights), + keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, @@ -1128,8 +1162,8 @@ class PoissonRegressionHead(test.TestCase): # exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2), # exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)] # = [1.0, 3.020, 1.482] - # sum_loss = 5.502 - expected_loss = 5.502 + # training_loss = (1.0 + 3.020 + 1.482) / 3 + expected_loss = 1.834 atol = 0.001 expected_train_result = b'my_train_op' def _train_op_fn(loss): diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py index c63514eb688af48577f0a3b7ce9e7478309f2c30..c41996b9c6871d294f157411662f2eb9d4c09e5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear_test.py +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -42,7 +43,9 @@ def _linear_estimator_fn( """Returns a LinearEstimator that uses regression_head.""" return linear.LinearEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 0346ddc24bffd61068177f4622bd03be4acd53d9..bbbc19cc4dfb4b23f9b707023fbfdd124f1f48de 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -30,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -226,8 +228,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access weights=example_weights_by_head, processed_labels=labels_by_head) + # TODO(b/65403806): Support regularization_losses arg. def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): """See `_Head`.""" if isinstance(logits, dict): logits_dict = logits @@ -248,9 +252,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train( + all_estimator_spec=all_estimator_spec, + optimizer=optimizer, + train_op_fn=train_op_fn) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec @@ -279,16 +284,21 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): + def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): """Merges list of `EstimatorSpec` for training. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. - train_op_fn: Function to create train op. See `create_estimator_spec` - documentation for more details. + optimizer: `Optimizer` instance to create train op. See + `create_estimator_spec` documentation for more details. + train_op_fn: Function to create train op. Used if `optimizer` is `None`. Returns: `EstimatorSpec` that merges all heads for TRAIN. + + Raises: + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode. """ losses = [] metrics = {} @@ -297,11 +307,20 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access # Metric keys already contain head.name. metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + loss, global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op_fn(loss), + train_op=train_op, eval_metric_ops=metrics) def _merge_predict(self, all_estimator_spec): @@ -319,6 +338,7 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access all_estimator_spec[0].export_outputs, self._heads[0].name), } + merged_predict_outputs = {} for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): @@ -327,8 +347,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access else: key = '%s/%s' % (k, head_name) export_outputs[key] = v + if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access + isinstance(v, export_output_lib.PredictOutput)): + for kp, vp in six.iteritems(v.outputs): + key = '%s/%s' % (head_name, kp) + merged_predict_outputs[key] = vp for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v + export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access + export_output_lib.PredictOutput(merged_predict_outputs)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 65ea89ba1b9236d0bf4d2de430fab168ef50bf97..d9e5aca2952d25a7d917f9d76f95ab89733115a0 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', + 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -158,6 +158,22 @@ class MultiHeadTest(test.TestCase): self.assertAllClose( expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict'].outputs['head1/probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict'].outputs['head2/probabilities'])) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict/head1'].outputs['probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict/head2'].outputs['probabilities'])) def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" @@ -181,8 +197,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1', + 'predict/head1', 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -238,8 +254,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', - 'head2', 'regression/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1', + 'predict/head1', 'head2', 'regression/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -283,10 +299,11 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15. expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 spec = multi_head.create_estimator_spec( @@ -300,8 +317,8 @@ class MultiHeadTest(test.TestCase): keys.LOSS + '/head1': expected_loss_head1, keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. - keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, + keys.LOSS_MEAN + '/head1': expected_loss_head1, + keys.LOSS_MEAN + '/head2': expected_loss_head2, # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, @@ -347,8 +364,8 @@ class MultiHeadTest(test.TestCase): tol = 1e-3 with self.test_session(): # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] - # (averaged over classes, sum-reduced over examples). - self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) + # (averaged over classes, averaged over examples). + self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): # Use different example weighting for each head weighting. @@ -383,18 +400,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -431,18 +448,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -466,14 +483,14 @@ class MultiHeadTest(test.TestCase): [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), } # Loss for the first head: - # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + - # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 - # = 28 + # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8 + # = 3.5 # Loss for the second head: - # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + - # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 - # = 74 - expected_training_loss = 28. + 74. + # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12 + # = 6.167 + expected_training_loss = 3.5 + 6.167 training_loss = multi_head.create_loss( features={}, @@ -495,8 +512,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -530,10 +547,46 @@ class MultiHeadTest(test.TestCase): _assert_simple_summaries(self, { metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) + def test_train_one_head_with_optimizer(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + multi_head = multi_head_lib.multi_head([head1]) + + logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} + labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -553,10 +606,12 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -592,9 +647,6 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, }, summary_str, tol) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index e0fae2c99292385c6dd32cc6002cee2076a2bb20..fa2697800ec1a44f215f3d5fc9be2197a9e58219 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -136,7 +136,7 @@ def replicate_model_fn(model_fn, the train_op argument of `EstimatorSpec`. loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This - argument can be used to replice only on the subset of available GPUs. + argument can be used to replicate only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index d46a18aacfcd911c56a9f22dc9581060c7b458a6..144b45982c8aec2e2b115c812b24e8843d60ce1e 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import re import shutil import tempfile +from absl.testing import parameterized import numpy as np import six @@ -57,26 +58,19 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import training -# TODO(isaprykin): Parametrize all the tests on -# replicate_model_fn._VariableDistributionMode when it's supported. -class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow_with_public_version(self): - return self._complete_flow_with_mode(mode=None) - - def test_complete_flow_with_mode_local_ps_server(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode. - SHARED_LOCAL_PARAMETER_SERVER) - - def test_complete_flow_with_mode_round_robin(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) - - def _complete_flow_with_mode(self, mode): + @parameterized.named_parameters( + ('PublicInterface', None), + ('ParameterServerMode', replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER), + ('RoundRobinMode', + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)) + def test_complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index c56c92a0a4a01218d1da5a6b366df3272d14b861..0a648d5d40e431bedb42017b15cabe078ac22fa7 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -66,6 +66,7 @@ tf_custom_op_py_library( "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -241,6 +242,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:training", "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -345,16 +347,3 @@ cuda_py_test( ], main = "python/kernel_tests/masked_matmul_benchmark.py", ) - -# All files -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index bbe842bd5ccc7357805adda1df42ba8799fcd8f2..363baa121ab3854a802ca3606e35597d31b35a57 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -21,14 +21,3 @@ tf_py_test( ], tags = ["notsan"], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index 44eab56011dad2f6fbe843b3569b4acc5c5e542a..ea8b9a17a27093cb57564861815edd6ecb18a014 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -67,14 +67,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index dd61f59585aee2e0245cfd6797b313b972c19bc5..2a6c97e8b9526894eba057505a2bf823ad778f56 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -353,7 +353,7 @@ class NearestNeighborsOp : public OpKernel { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); const int64 num_threads = worker_threads.num_threads; // This kernel might be configured to use fewer than the total number of - // available CPUs on the host machine. To avoid descructive interference + // available CPUs on the host machine. To avoid destructive interference // with other jobs running on the host machine, we must only use a fraction // of total available L3 cache. Unfortunately, we cannot query the host // machine to get the number of physical CPUs. So, we use a fixed per-CPU diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 23137e0a973c0bdd2cdbd97159f7fd310178bf54..84e80791f4991ad2b67d0a00ee1e00cf0d0daadc 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -41,11 +41,12 @@ from tensorflow.python.platform import resource_loader _clustering_ops = loader.load_op_library( resource_loader.get_path_to_datafile('_clustering_ops.so')) -# Euclidean distance between vectors U and V is defined as ||U - V||_F which is -# the square root of the sum of the absolute squares of the elements difference. +# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) +# which is the square root of the sum of the absolute squares of the elements +# difference. SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean' # Cosine distance between vectors U and V is defined as -# 1 - (U \dot V) / (||U||_F ||V||_F) +# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\) COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' @@ -472,8 +473,8 @@ class KMeans(object): # Locally compute the sum of inputs mapped to each id. # For a cluster with old cluster value x, old count n, and with data # d_1,...d_k newly assigned to it, we recompute the new value as - # x += (sum_i(d_i) - k * x) / (n + k). - # Compute sum_i(d_i), see comment above. + # \\(x += (sum_i(d_i) - k * x) / (n + k)\\). + # Compute \\(sum_i(d_i)\\), see comment above. cluster_center_updates = math_ops.unsorted_segment_sum( inp, unique_idx, num_unique_cluster_idx) # Shape to enable broadcasting count_updates and learning_rate to inp. diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 054888e734086c153f7af59f4548d4d20abab813..811fa89bc38c61b16710a441b99d9e5dfac67668 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -51,9 +51,9 @@ class WALSModel(object): r"""A model for Weighted Alternating Least Squares matrix factorization. It minimizes the following loss function over U, V: - \\( - \|\sqrt W \odot (A - U V^T) \|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) - )\\ + $$ + \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) + $$ where, A: input matrix, W: weight matrix. Note that the (element-wise) square root of the weights @@ -61,12 +61,12 @@ class WALSModel(object): U, V: row_factors and column_factors matrices, \\(\lambda)\\: regularization. Also we assume that W is of the following special form: - \\( W_{ij} = W_0 + R_i * C_j )\\ if \\(A_{ij} \ne 0)\\, - \\(W_{ij} = W_0)\\ otherwise. + \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\), + \\(W_{ij} = W_0\\) otherwise. where, - \\(W_0)\\: unobserved_weight, - \\(R_i)\\: row_weights, - \\(C_j)\\: col_weights. + \\(W_0\\): unobserved_weight, + \\(R_i\\): row_weights, + \\(C_j\\): col_weights. Note that the current implementation supports two operation modes: The default mode is for the condition where row_factors and col_factors can individually @@ -82,14 +82,15 @@ class WALSModel(object): normalized as follows: _, _, unregularized_loss, regularization, sum_weights = update_row_factors(sp_input) - if sp_input contains the rows {A_i, i \in I}, and the input matrix A has n - total rows, then the minibatch loss = unregularized_loss + regularization is - \\( + if sp_input contains the rows \\({A_i, i \in I}\\), and the input matrix A + has n total rows, then the minibatch loss = unregularized_loss + + regularization is + $$ (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2 - )\\ + $$ The sum_weights tensor contains the normalized sum of weights - sum(W_I) * n / |I|. + \\(sum(W_I) * n / |I|\\). A typical usage example (pseudocode): @@ -106,7 +107,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per interation sweep before the row(column) update + # To be run once per integration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -118,9 +119,9 @@ class WALSModel(object): init_row_update_op = model.initialize_row_update_op init_col_update_op = model.initialize_col_update_op - # Ops to upate row(column). This can either take the entire sparse tensor - # or slices of sparse tensor. For distributed trainer, each trainer - # handles just part of the matrix. + # Ops to update row(column). This can either take the entire sparse + # tensor or slices of sparse tensor. For distributed trainer, each + # trainer handles just part of the matrix. _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors( sp_input=matrix_slices_from_queue_for_worker_shard) row_loss = unreg_row_loss + row_reg @@ -220,10 +221,10 @@ class WALSModel(object): in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of inner lists matching the number of row factor shards and the elements in each inner list are the weights for the rows of the corresponding row - factor shard. In this case, w_ij = unonbserved_weight + + factor shard. In this case, w_ij = unobserved_weight + row_weights[i] * col_weights[j]. - If this is a single non-negative real number, this value is used for - all row weights and w_ij = unobserved_weight + row_weights * + all row weights and \\(w_ij\\) = unobserved_weight + row_weights * col_weights[j]. Note that it is allowed to have row_weights as a list while col_weights a single number or vice versa. @@ -435,7 +436,7 @@ class WALSModel(object): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calcuated value from the factors. + A op that updates the gramian with the calculated value from the factors. """ partial_gramians = [] for f in factors: @@ -564,7 +565,7 @@ class WALSModel(object): Note that specifically this initializes the cache of the row and column weights on workers when `use_factors_weights_cache` is True. In this case, - if these weights are being calcualted and reset after the object is created, + if these weights are being calculated and reset after the object is created, it is important to ensure this ops is run afterwards so the cache reflects the correct values. """ @@ -665,18 +666,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the rows {A_{i, :}, i \in I}, and the input - matrix A has n total rows, then the unregularized loss is: - (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I| + term. If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the + input matrix A has n total rows, then the unregularized loss is: + \\(\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the rows {A_{i, :}, i \in I}, and the input matrix - A has n total rows, then the regularization term is: - \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2. + If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the input + matrix A has n total rows, then the regularization term is: + \\(\lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2\\). sum_weights: The sum of the weights W_I corresponding to sp_input, - normalized by a factor of n / |I|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(n / |I|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( True, sp_input=sp_input, transpose_input=transpose_input) @@ -698,18 +699,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the columns {A_{:, j}, j \in J}, and the - input matrix A has m total columns, then the unregularized loss is: - (\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I| + term. If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and + the input matrix A has m total columns, then the unregularized loss is: + \\(\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the columns {A_{:, j}, j \in J}, and the input - matrix A has m total columns, then the regularization term is: - \lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2. + If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and the + input matrix A has m total columns, then the regularization term is: + \\(\lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2\\). sum_weights: The sum of the weights W_J corresponding to sp_input, - normalized by a factor of m / |J|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(m / |J|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( False, sp_input=sp_input, transpose_input=transpose_input) @@ -720,8 +721,8 @@ class WALSModel(object): projection_weights=None): """Projects the row factors. - This computes the row embedding u_i for an observed row a_i by solving - one iteration of the update equations. + This computes the row embedding \\(u_i\\) for an observed row \\(a_i\\) by + solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of rows. Please note that the @@ -753,8 +754,8 @@ class WALSModel(object): projection_weights=None): """Projects the column factors. - This computes the column embedding v_j for an observed column a_j by solving - one iteration of the update equations. + This computes the column embedding \\(v_j\\) for an observed column + \\(a_j\\) by solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of columns. Please note that @@ -938,7 +939,7 @@ class WALSModel(object): loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input) if transpose_input else new_sp_input) # sp_approx is the low rank estimate of the input matrix, formed by - # computing the product for (i, j) in loss_sp_input.indices. + # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices. sp_approx_vals = gen_factorization_ops.masked_matmul( new_left_values, right, diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index c8137339155ef1da8ee53967eea84a550f12ecbc..bb5140aeb3bf0238ca7cb52067ea6328dd1736d5 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -210,7 +210,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -283,8 +283,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 3 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, @@ -385,7 +385,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -462,8 +462,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 2 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 98d6434f4752b224201e38bed05ccd14428a758b..5d77bc77e124378e13667673e4e841c0a1135b31 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -280,7 +280,7 @@ class GmmAlgorithm(object): self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): - """Defines the full covariance probabilties per example in a class. + """Defines the full covariance probabilities per example in a class. Updates a matrix with dimension num_examples X num_classes. @@ -344,7 +344,7 @@ class GmmAlgorithm(object): def _define_prior_log_prob_operation(self, shard_id): """Computes the prior probability of all samples. - Updates a vector where each item is the prior probabibility of an + Updates a vector where each item is the prior probability of an input example. Args: @@ -357,8 +357,8 @@ class GmmAlgorithm(object): # Shape broadcasting. probs = array_ops.expand_dims(self._probs[shard_id], 0) # Membership weights are computed as: - # w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)} - # {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)} + # $$w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}$$ + # $$ {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}$$ # where "i" is the i-th example, "k" is the k-th mixture, theta are # the model parameters and y_i the observations. # These are defined for each shard. diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 00a4734eb6d89cd02484f1c5161366377cc71208..4fc9c96e9d0a317ef757d5e1bb6563ed7c8832af 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -210,7 +210,7 @@ class GMMTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 7319eaa7de8db8e4677bdf64af3b0a72c1007a90..bfe338c9f9a7b761cfcd627b92f1682af97630c9 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,6 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -105,24 +106,32 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): logging.info(e) -def _parse_tensor_or_dict(features): +def _parse_features_if_necessary(features, feature_columns): """Helper function to convert the input points into a usable format. Args: - features: The input points. + features: The input features. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column instances + that can be passed to `tf.feature_column.input_layer`. If this is None, + all features will be used. Returns: - If `features` is a dict of `k` features, each of which is a vector of `n` - scalars, the return value is a Tensor of shape `(n, k)` representing `n` - input points, where the items in the `k` dimension are sorted - lexicographically by `features` key. If `features` is not a dict, it is - returned unmodified. + If `features` is a dict of `k` features (optionally filtered by + `feature_columns`), each of which is a vector of `n` scalars, the return + value is a Tensor of shape `(n, k)` representing `n` input points, where the + items in the `k` dimension are sorted lexicographically by `features` key. + If `features` is not a dict, it is returned unmodified. """ - if isinstance(features, dict): - keys = sorted(features.keys()) - with ops.colocate_with(features[keys[0]]): - features = array_ops.concat([features[k] for k in keys], axis=1) - return features + if not isinstance(features, dict): + return features + + if feature_columns: + return fc.input_layer(features, feature_columns) + + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + return array_ops.concat([features[k] for k in keys], axis=1) class _ModelFn(object): @@ -130,7 +139,8 @@ class _ModelFn(object): def __init__(self, num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance): + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns): self._num_clusters = num_clusters self._initial_clusters = initial_clusters self._distance_metric = distance_metric @@ -139,6 +149,7 @@ class _ModelFn(object): self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries self._relative_tolerance = relative_tolerance + self._feature_columns = feature_columns def model_fn(self, features, mode, config): """Model function for the estimator. @@ -166,7 +177,7 @@ class _ModelFn(object): # input_points is a single Tensor. Therefore, the sharding functionality # in clustering_ops is unused, and some of the values below are lists of a # single item. - input_points = _parse_tensor_or_dict(features) + input_points = _parse_features_if_necessary(features, self._feature_columns) # Let N = the number of input_points. # all_distances: A list of one matrix of shape (N, num_clusters). Each value @@ -316,7 +327,8 @@ class KMeansClustering(estimator.Estimator): mini_batch_steps_per_iteration=1, kmeans_plus_plus_num_retries=2, relative_tolerance=None, - config=None): + config=None, + feature_columns=None): """Creates an Estimator for running KMeans training and inference. This Estimator implements the following variants of the K-means algorithm: @@ -362,11 +374,11 @@ class KMeansClustering(estimator.Estimator): than `num_clusters`, a TensorFlow runtime error occurs. distance_metric: The distance metric used for clustering. One of: * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance - between vectors `u` and `v` is defined as `||u - v||_2` which is - the square root of the sum of the absolute squares of the elements' - difference. + between vectors `u` and `v` is defined as `\\(||u - v||_2\\)` + which is the square root of the sum of the absolute squares of + the elements' difference. * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors - `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + `u` and `v` is defined as `\\(1 - (u . v) / (||u||_2 ||v||_2)\\)`. random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. See explanation above. @@ -383,6 +395,10 @@ class KMeansClustering(estimator.Estimator): iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. config: See @{tf.estimator.Estimator}. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column + instances that can be passed to `tf.feature_column.input_layer`. If this + is None, all features will be used. Raises: ValueError: An invalid argument was passed to `initial_clusters` or @@ -402,7 +418,8 @@ class KMeansClustering(estimator.Estimator): model_fn=_ModelFn( num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns).model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95..88eb9cf692992fe2e1fc4f060ac98dd721c22307 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,6 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -226,6 +227,44 @@ class KMeansTest(KMeansTestBase): self._infer_helper(kmeans, clusters, 10) self._infer_helper(kmeans, clusters, 1) + def _parse_feature_dict_helper(self, features, parsed_feature_dict): + # Perform a sanity check. + self.assertEqual(features.shape, parsed_feature_dict.shape) + self.assertEqual(features.dtype, parsed_feature_dict.dtype) + # Then check that running the tensor yields the original list of points. + with self.test_session() as sess: + parsed_points = sess.run(parsed_feature_dict) + self.assertAllEqual(self.points, parsed_points) + + def test_parse_features(self): + """Tests the various behaviours of kmeans._parse_features_if_necessary.""" + + # No-op if a tensor is passed in. + features = constant_op.constant(self.points) + parsed_features = kmeans_lib._parse_features_if_necessary(features, None) + self.assertAllEqual(features, parsed_features) + + # All values from a feature dict are transformed into a tensor. + feature_dict = { + 'x': [[point[0]] for point in self.points], + 'y': [[point[1]] for point in self.points] + } + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict, None) + self._parse_feature_dict_helper(features, parsed_feature_dict) + + # Only the feature_columns of a feature dict are transformed into a tensor. + feature_dict_with_extras = { + 'foo': 'bar', + 'x': [[point[0]] for point in self.points], + 'baz': {'fizz': 'buzz'}, + 'y': [[point[1]] for point in self.points] + } + feature_columns = [fc.numeric_column(key='x'), fc.numeric_column(key='y')] + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict_with_extras, feature_columns) + self._parse_feature_dict_helper(features, parsed_feature_dict) + class KMeansTestMultiStageInit(KMeansTestBase): @@ -374,7 +413,7 @@ class KMeansCosineDistanceTest(KMeansTestBase): self.assertAllClose(score, self.true_score, atol=1e-2) def test_predict_kmeans_plus_plus(self): - # Most points are concetrated near one center. KMeans++ is likely to find + # Most points are concentrated near one center. KMeans++ is likely to find # the less populated centers. points = np.array( [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], @@ -394,7 +433,6 @@ class KMeansCosineDistanceTest(KMeansTestBase): true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_score = len(points) - np.tensordot( normalize(points), true_centers[true_assignments]) - kmeans = kmeans_lib.KMeansClustering( 3, initial_clusters=self.initial_clusters, @@ -566,7 +604,7 @@ class KMeansTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9..ca46c39baa16a7fddb96121e0402fc35d24ce1c2 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -216,7 +216,7 @@ def _wals_factorization_model_function(features, labels, mode, params): name=WALSMatrixFactorization.LOSS, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + # \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\) rwse_var = variable_scope.variable( 0., trainable=False, @@ -235,7 +235,7 @@ def _wals_factorization_model_function(features, labels, mode, params): num_items: An integer, the total number of items of this axis. update_fn: A function that takes one argument (`sp_input`), and that returns a tuple of - * new_factors: A flot Tensor of the factor values after update. + * new_factors: A float Tensor of the factor values after update. * update_op: a TensorFlow op which updates the factors. * loss: A float Tensor, the unregularized loss. * reg_loss: A float Tensor, the regularization loss. @@ -490,11 +490,11 @@ class WALSMatrixFactorization(estimator.Estimator): and the problem simplifies to ALS. Note that, in this case, col_weights must also be set to "None". - List of lists of non-negative scalars, of the form - [[w_0, w_1, ...], [w_k, ... ], [...]], + \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\), where the number of inner lists equal to the number of row factor shards and the elements in each inner list are the weights for the rows of that shard. In this case, - w_ij = unonbserved_weight + row_weights[i] * col_weights[j]. + \\(w_ij = unonbserved_weight + row_weights[i] * col_weights[j]\\). - A non-negative scalar: This value is used for all row weights. Note that it is allowed to have row_weights as a list and col_weights as a scalar, or vice-versa. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 3614b2b15a6cbdd73f9f24c7e4e4534228d31499..aab7d0c9e8874269bfa5f33193b0dc0ba4bbc9cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "feature_column_py", srcs = ["__init__.py"], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index e60116966fc8d8bb0745f50a0238f10f02af4167..555beddeaab419bcb23d06f960d370b706d744c8 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -166,6 +166,10 @@ def sequence_categorical_column_with_identity( Returns: A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc._SequenceCategoricalColumn( fc.categorical_column_with_identity( @@ -205,6 +209,10 @@ def sequence_categorical_column_with_hash_bucket( Returns: A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( fc.categorical_column_with_hash_bucket( @@ -257,6 +265,13 @@ def sequence_categorical_column_with_vocabulary_file( Returns: A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( fc.categorical_column_with_vocabulary_file( @@ -311,6 +326,12 @@ def sequence_categorical_column_with_vocabulary_list( Returns: A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. """ return fc._SequenceCategoricalColumn( fc.categorical_column_with_vocabulary_list( @@ -352,8 +373,17 @@ def sequence_numeric_column( Returns: A `_SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. """ - # TODO(b/73160931): Add validations. + shape = fc._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + return _SequenceNumericColumn( key, shape=shape, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index b64f086376dad65c1f32bee4bfce9334a60fd24a..88f5d535162939e063eb1e7f43d495137c5adef4 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -662,6 +662,32 @@ class SequenceIndicatorColumnTest(test.TestCase): class SequenceNumericColumnTest(test.TestCase): + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py deleted file mode 100644 index 4ed7268e7a921284eed7767d870e56ecac39a3b1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Experimental methods for tf.feature_column sequence input.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -import abc -import collections - - -from tensorflow.python.feature_column import feature_column as fc -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import variable_scope - -# TODO(b/73160931): Fix pydoc. -# pylint: disable=g-doc-args,missing-docstring,protected-access -# TODO(b/73827486): Support SequenceExample. - - -def sequence_input_layer( - features, - feature_columns, - weight_collections=None, - trainable=True, - scope=None): - """"Builds input layer for sequence input. - - All `feature_columns` must be sequence dense columns with the same - `sequence_length`. The output of this method can be fed into sequence - networks, such as RNN. - - The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ from - batch to batch. - - If multiple `feature_columns` are given with `Di` `num_elements` each, their - outputs are concatenated. So, the final `Tensor` has shape - `[batch_size, T, D0 + D1 + ... + Dn]`. - - Example: - - ```python - rating = sequence_numeric_column('rating') - watches = sequence_categorical_column_with_identity( - 'watches', num_buckets=1000) - watches_embedding = embedding_column(watches, dimension=10) - columns = [rating, watches] - - features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) - - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) - ``` - - Returns: - An `(input_layer, sequence_length)` tuple where: - - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ - from batch to batch. `D` is the sum of `num_elements` for all - `feature_columns`. - - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence - length for each example. - Raises: - ValueError: If any of the `feature_columns` is the wrong type. - """ - feature_columns = fc._clean_feature_columns(feature_columns) - for c in feature_columns: - if not isinstance(c, _SequenceDenseColumn): - raise ValueError( - 'All feature_columns must be of type _SequenceDenseColumn. ' - 'Given (type {}): {}'.format(type(c), c)) - - with variable_scope.variable_scope( - scope, default_name='sequence_input_layer', values=features.values()): - builder = fc._LazyBuilder(features) - output_tensors = [] - sequence_lengths = [] - ordered_columns = [] - for column in sorted(feature_columns, key=lambda x: x.name): - ordered_columns.append(column) - with variable_scope.variable_scope( - None, default_name=column._var_scope_name): - dense_tensor, sequence_length = column._get_sequence_dense_tensor( - builder, - weight_collections=weight_collections, - trainable=trainable) - # Flattens the final dimension to produce a 3D Tensor. - num_elements = column._variable_shape.num_elements() - shape = array_ops.shape(dense_tensor) - output_tensors.append( - array_ops.reshape( - dense_tensor, - shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) - sequence_lengths.append(sequence_length) - fc._verify_static_batch_size_equality(output_tensors, ordered_columns) - # TODO(b/73160931): Verify sequence_length equality. - return array_ops.concat(output_tensors, -1), sequence_lengths[0] - - -# TODO(b/73160931): Add remaining categorical columns. -def sequence_categorical_column_with_identity( - key, num_buckets, default_value=None): - return _SequenceCategoricalColumn( - fc.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) - - -# TODO(b/73160931): Merge with embedding_column -def _sequence_embedding_column( - categorical_column, dimension, initializer=None, ckpt_to_load_from=None, - tensor_name_in_ckpt=None, max_norm=None, trainable=True): - if not isinstance(categorical_column, _SequenceCategoricalColumn): - raise ValueError( - 'categorical_column must be of type _SequenceCategoricalColumn. ' - 'Given (type {}): {}'.format( - type(categorical_column), categorical_column)) - return _SequenceEmbeddingColumn( - fc.embedding_column( - categorical_column, - dimension=dimension, - initializer=initializer, - ckpt_to_load_from=ckpt_to_load_from, - tensor_name_in_ckpt=tensor_name_in_ckpt, - max_norm=max_norm, - trainable=trainable)) - - -def sequence_numeric_column( - key, - shape=(1,), - default_value=0., - dtype=dtypes.float32): - # TODO(b/73160931): Add validations. - return _SequenceNumericColumn( - key, - shape=shape, - default_value=default_value, - dtype=dtype) - - -class _SequenceDenseColumn(fc._FeatureColumn): - """Represents dense sequence data.""" - - __metaclass__ = abc.ABCMeta - - TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name - 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) - - @abc.abstractproperty - def _variable_shape(self): - """`TensorShape` without batch and sequence dimensions.""" - pass - - @abc.abstractmethod - def _get_sequence_dense_tensor( - self, inputs, weight_collections=None, trainable=None): - """Returns a `TensorSequenceLengthPair`.""" - pass - - -def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): - with ops.name_scope(None, 'sequence_length') as name_scope: - row_ids = sp_tensor.indices[:, 0] - column_ids = sp_tensor.indices[:, 1] - column_ids += array_ops.ones_like(column_ids) - seq_length = ( - math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) - # If the last n rows do not have ids, seq_length will have shape - # [batch_size - n]. Pad the remaining values with zeros. - n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] - padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) - return array_ops.concat([seq_length, padding], axis=0, name=name_scope) - - -class _SequenceCategoricalColumn( - fc._CategoricalColumn, - collections.namedtuple( - '_SequenceCategoricalColumn', ['categorical_column'])): - - @property - def name(self): - return self.categorical_column.name - - @property - def _parse_example_spec(self): - return self.categorical_column._parse_example_spec - - def _transform_feature(self, inputs): - return self.categorical_column._transform_feature(inputs) - - @property - def _num_buckets(self): - return self.categorical_column._num_buckets - - def _get_sparse_tensors(self, inputs, weight_collections=None, - trainable=None): - sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) - id_tensor = sparse_tensors.id_tensor - weight_tensor = sparse_tensors.weight_tensor - # Expands final dimension, so that embeddings are not combined during - # embedding lookup. - check_id_rank = check_ops.assert_equal( - array_ops.rank(id_tensor), 2, - data=[ - 'Column {} expected ID tensor of rank 2. '.format(self.name), - 'id_tensor shape: ', array_ops.shape(id_tensor)]) - with ops.control_dependencies([check_id_rank]): - id_tensor = sparse_ops.sparse_reshape( - id_tensor, - shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) - if weight_tensor is not None: - check_weight_rank = check_ops.assert_equal( - array_ops.rank(weight_tensor), 2, - data=[ - 'Column {} expected weight tensor of rank 2.'.format(self.name), - 'weight_tensor shape:', array_ops.shape(weight_tensor)]) - with ops.control_dependencies([check_weight_rank]): - weight_tensor = sparse_ops.sparse_reshape( - weight_tensor, - shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) - return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) - - def _sequence_length(self, inputs): - sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) - return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor) - - -class _SequenceEmbeddingColumn( - _SequenceDenseColumn, - collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])): - - @property - def name(self): - return self.embedding_column.name - - @property - def _parse_example_spec(self): - return self.embedding_column._parse_example_spec - - def _transform_feature(self, inputs): - return self.embedding_column._transform_feature(inputs) - - @property - def _variable_shape(self): - return self.embedding_column._variable_shape - - def _get_sequence_dense_tensor( - self, inputs, weight_collections=None, trainable=None): - dense_tensor = self.embedding_column._get_dense_tensor( - inputs=inputs, - weight_collections=weight_collections, - trainable=trainable) - sequence_length = self.embedding_column.categorical_column._sequence_length( - inputs) - return _SequenceDenseColumn.TensorSequenceLengthPair( - dense_tensor=dense_tensor, sequence_length=sequence_length) - - -class _SequenceNumericColumn( - _SequenceDenseColumn, - collections.namedtuple( - '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): - - @property - def name(self): - return self.key - - @property - def _parse_example_spec(self): - return {self.key: parsing_ops.VarLenFeature(self.dtype)} - - def _transform_feature(self, inputs): - return inputs.get(self.key) - - @property - def _variable_shape(self): - return tensor_shape.TensorShape(self.shape) - - def _get_sequence_dense_tensor( - self, inputs, weight_collections=None, trainable=None): - # Do nothing with weight_collections and trainable since no variables are - # created in this function. - del weight_collections - del trainable - sp_tensor = inputs.get(self) - dense_tensor = sparse_ops.sparse_tensor_to_dense( - sp_tensor, default_value=self.default_value) - # Reshape into [batch_size, T, variable_shape]. - dense_shape = array_ops.concat( - [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], - axis=0) - dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) - sequence_length = _sequence_length_from_sparse_tensor( - sp_tensor, num_elements=self._variable_shape.num_elements()) - return _SequenceDenseColumn.TensorSequenceLengthPair( - dense_tensor=dense_tensor, sequence_length=sequence_length) - -# pylint: enable=g-doc-args,missing-docstring,protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py deleted file mode 100644 index 59674869a27c3a40ab9cb3dcede384d1cda7ce27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for sequential_feature_column.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc -from tensorflow.python.feature_column.feature_column import _LazyBuilder -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.platform import test -from tensorflow.python.training import monitored_session - - -class SequenceInputLayerTest(test.TestCase): - - def test_embedding_column(self): - vocabulary_size = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [1] - # example 1, ids [2, 0] - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - - embedding_dimension_a = 2 - embedding_values_a = ( - (1., 2.), # id 0 - (3., 4.), # id 1 - (5., 6.) # id 2 - ) - embedding_dimension_b = 3 - embedding_values_b = ( - (11., 12., 13.), # id 0 - (14., 15., 16.), # id 1 - (17., 18., 19.) # id 2 - ) - def _get_initializer(embedding_dimension, embedding_values): - def _initializer(shape, dtype, partition_info): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - return _initializer - - expected_input_layer = [ - # example 0, ids_a [2], ids_b [1] - [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], - # example 1, ids_a [0, 1], ids_b [2, 0] - [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], - ] - expected_sequence_length = [1, 2] - - categorical_column_a = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column_a = sfc._sequence_embedding_column( - categorical_column_a, dimension=embedding_dimension_a, - initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) - categorical_column_b = sfc.sequence_categorical_column_with_identity( - key='bbb', num_buckets=vocabulary_size) - embedding_column_b = sfc._sequence_embedding_column( - categorical_column_b, dimension=embedding_dimension_b, - initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) - - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - # Test that columns are reordered alphabetically. - feature_columns=[embedding_column_b, embedding_column_a]) - - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('sequence_input_layer/aaa_embedding/embedding_weights:0', - 'sequence_input_layer/bbb_embedding/embedding_weights:0'), - tuple([v.name for v in global_vars])) - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) - self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) - self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_numeric_column(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_input_layer = [ - [[0.], [1.]], - [[10.], [0.]], - ] - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa') - - input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_numeric_column_multi_dim(self): - """Tests sequence_input_layer for multi-dimensional numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - # The output of numeric_column._get_dense_tensor should be flattened. - expected_input_layer = [ - [[0., 1., 2., 3.], [4., 5., 6., 7.]], - [[10., 11., 12., 13.], [0., 0., 0., 0.]], - ] - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - - input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - -def _assert_sparse_tensor_value(test_case, expected, actual): - test_case.assertEqual(np.int64, np.array(actual.indices).dtype) - test_case.assertAllEqual(expected.indices, actual.indices) - - test_case.assertEqual( - np.array(expected.values).dtype, np.array(actual.values).dtype) - test_case.assertAllEqual(expected.values, actual.values) - - test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) - test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) - - -class SequenceCategoricalColumnWithIdentityTest(test.TestCase): - - def test_get_sparse_tensors(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - expected_sparse_ids = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=np.array((1, 2, 0), dtype=np.int64), - dense_shape=(2, 2, 1)) - - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) - - self.assertIsNone(id_weight_pair.weight_tensor) - with monitored_session.MonitoredSession() as sess: - _assert_sparse_tensor_value( - self, - expected_sparse_ids, - id_weight_pair.id_tensor.eval(session=sess)) - - def test_get_sparse_tensors_inputs3d(self): - """Tests _get_sparse_tensors when the input is already 3D Tensor.""" - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), - values=(1, 2, 0), - dense_shape=(2, 2, 1)) - - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r'Column aaa expected ID tensor of rank 2\.\s*' - r'id_tensor shape:\s*\[2 2 1\]'): - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': inputs})) - with monitored_session.MonitoredSession() as sess: - id_weight_pair.id_tensor.eval(session=sess) - - def test_sequence_length(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(1, 2, 0), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] - - sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_sequence_length_with_zeros(self): - column = sfc.sequence_categorical_column_with_identity( - 'aaa', num_buckets=3) - inputs = sparse_tensor.SparseTensorValue( - indices=((1, 0), (3, 0), (3, 1)), - values=(1, 2, 0), - dense_shape=(5, 2)) - expected_sequence_length = [0, 1, 0, 2, 0] - - sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - -class SequenceEmbeddingColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 1), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 2)) - - embedding_dimension = 2 - embedding_values = ( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ) - def _initializer(shape, dtype, partition_info): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - - expected_lookups = [ - # example 0, ids [2] - [[7., 11.], [0., 0.]], - # example 1, ids [0, 1] - [[1., 2.], [3., 5.]], - # example 2, ids [] - [[0., 0.], [0., 0.]], - # example 3, ids [1] - [[3., 5.], [0., 0.]], - ] - - categorical_column = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = sfc._sequence_embedding_column( - categorical_column, dimension=embedding_dimension, - initializer=_initializer) - - embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('embedding_weights:0',), tuple([v.name for v in global_vars])) - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) - self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) - - def test_sequence_length(self): - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - expected_sequence_length = [1, 2] - - categorical_column = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = sfc._sequence_embedding_column( - categorical_column, dimension=2) - - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_sequence_length_with_empty_rows(self): - """Tests _sequence_length when some examples do not have ids.""" - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [] - # example 1, ids [2] - # example 2, ids [0, 1] - # example 3, ids [] - # example 4, ids [1] - # example 5, ids [] - indices=((1, 0), (2, 0), (2, 1), (4, 0)), - values=(2, 0, 1, 1), - dense_shape=(6, 2)) - expected_sequence_length = [0, 1, 2, 0, 1, 0] - - categorical_column = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = sfc._sequence_embedding_column( - categorical_column, dimension=2) - - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - -class SequenceNumericColumnTest(test.TestCase): - - def test_get_sequence_dense_tensor(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_dense_tensor = [ - [[0.], [1.]], - [[10.], [0.]], - ] - numeric_column = sfc.sequence_numeric_column('aaa') - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def test_get_sequence_dense_tensor_with_shape(self): - """Tests get_sequence_dense_tensor with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_dense_tensor = [ - [[0., 1., 2.], [3., 4., 5.]], - [[10., 11., 12.], [0., 0., 0.]], - ] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def test_get_dense_tensor_multi_dim(self): - """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] - # example 1, [[[10., 11.], [12., 13.]]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), - (1, 0), (1, 1), (1, 2), (1, 3)), - values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), - dense_shape=(2, 8)) - expected_dense_tensor = [ - [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], - [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], - ] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - - dense_tensor, _ = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_dense_tensor, dense_tensor.eval(session=sess)) - - def test_sequence_length(self): - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0., 1., 2.], [3., 4., 5.]] - # example 1, [[10., 11., 12.]] - indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), - (1, 0), (1, 1), (1, 2)), - values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), - dense_shape=(2, 6)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_sequence_length_with_shape(self): - """Tests _sequence_length with shape !=(1,).""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [[0.], [1]] - # example 1, [[10.]] - indices=((0, 0), (0, 1), (1, 0)), - values=(0., 1., 10.), - dense_shape=(2, 2)) - expected_sequence_length = [2, 1] - numeric_column = sfc.sequence_numeric_column('aaa') - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - def test_sequence_length_with_empty_rows(self): - """Tests _sequence_length when some examples do not have ids.""" - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, values [] - # example 1, values [[0.], [1.]] - # example 2, [[2.]] - # example 3, values [] - # example 4, [[3.]] - # example 5, values [] - indices=((1, 0), (1, 1), (2, 0), (4, 0)), - values=(0., 1., 2., 3.), - dense_shape=(6, 2)) - expected_sequence_length = [0, 2, 1, 0, 1, 0] - numeric_column = sfc.sequence_numeric_column('aaa') - - _, sequence_length = numeric_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) - - with monitored_session.MonitoredSession() as sess: - self.assertAllEqual( - expected_sequence_length, sequence_length.eval(session=sess)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index eccce99071dc1477cf4f3bb152f3304b3b0fc35a..f7b3273a4d35eadb9fad49399b7bf18d4bd33503 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -180,15 +180,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 6b455567d766dbe6d380a498bd7f521db27e077b..59bad8982dd163f89f37e1a0a9d5017d0c495de3 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -74,15 +74,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index ac043fda0638e61f422e769ab3047a53a1b377bd..b1c8ad49eaf8d2400e431fcf4820fca6e0314557 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -321,15 +321,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 3398b3fd1c1036091bfadf548f7d44dbf9eb1046..cbb68bd3eb257f9472515e5c29ce4f02057be321 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -83,6 +83,7 @@ See the @{$python/contrib.framework} guide. @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@argsort @@py_func @@sort diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 3cad1fee1984042e3a9ab91a0af70cbaca25cece..5b150339953f961c756c0909dd1795341159b9cd 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -68,7 +68,7 @@ from tensorflow.python.util import tf_decorator __all__ = [ 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', - 'arg_scoped_arguments' + 'arg_scoped_arguments', 'arg_scope_func_key' ] _ARGSTACK = [{}] @@ -89,7 +89,7 @@ def current_arg_scope(): return stack[-1] -def _key_op(op): +def arg_scope_func_key(op): return getattr(op, '_key_op', str(op)) @@ -103,9 +103,9 @@ def _kwarg_names(func): def _add_op(op): - key_op = _key_op(op) - if key_op not in _DECORATED_OPS: - _DECORATED_OPS[key_op] = _kwarg_names(op) + key = arg_scope_func_key(op) + if key not in _DECORATED_OPS: + _DECORATED_OPS[key] = _kwarg_names(op) @tf_contextlib.contextmanager @@ -147,16 +147,16 @@ def arg_scope(list_ops_or_scope, **kwargs): try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: - key_op = _key_op(op) + key = arg_scope_func_key(op) if not has_arg_scope(op): raise ValueError('%s is not decorated with @add_arg_scope', _name_op(op)) - if key_op in current_scope: - current_kwargs = current_scope[key_op].copy() + if key in current_scope: + current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) - current_scope[key_op] = current_kwargs + current_scope[key] = current_kwargs else: - current_scope[key_op] = kwargs.copy() + current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: @@ -176,14 +176,14 @@ def add_arg_scope(func): def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs - key_func = _key_op(func) + key_func = arg_scope_func_key(func) if key_func in current_scope: current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) _add_op(func) - setattr(func_with_args, '_key_op', _key_op(func)) + setattr(func_with_args, '_key_op', arg_scope_func_key(func)) return tf_decorator.make_decorator(func, func_with_args) @@ -196,7 +196,7 @@ def has_arg_scope(func): Returns: a boolean. """ - return _key_op(func) in _DECORATED_OPS + return arg_scope_func_key(func) in _DECORATED_OPS def arg_scoped_arguments(func): @@ -209,4 +209,4 @@ def arg_scoped_arguments(func): a list of kwargs names. """ assert has_arg_scope(func) - return _DECORATED_OPS[_key_op(func)] + return _DECORATED_OPS[arg_scope_func_key(func)] diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 7ba9d4ffa90f6860629b15a2ea91e0c573bf6368..4c3879d4fc08b53ea8be5f1256a830a64fb39af6 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -170,6 +170,30 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func1_args) self.assertDictEqual(kwargs, func1_kwargs) + def testNestedArgScopeObjectCreatedOutsideScopeOverridesArgScope(self): + + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + scope_object = get_scope_object() + with arg_scope([func1], b=2, d=10): + with arg_scope(scope_object): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1]}) + + def testArgScopeObjectCreatedWithinScopeInheritsArgScope(self): + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + with arg_scope([func1], b=2, d=10): + with arg_scope(get_scope_object()): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1], 'd': 10}) + def testSharedArgScope(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index cc19372acf956371c2d029c7b8eb5534c3789413..bd764ed57a6da0a4d356235108e998a80ac34362 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -24,10 +24,8 @@ import collections # from tensorflow.core.protobuf import critical_section_pb2 from tensorflow.python.eager import context -from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -48,6 +46,26 @@ class _ExecutionSignature( pass +def _identity(x): + """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" + if isinstance(x, tensor_array_ops.TensorArray): + return x.identity() + elif isinstance(x, ops.Operation): + return control_flow_ops.group(x) + elif context.executing_eagerly() and x is None: + return None + else: + return array_ops.identity(x) + + +def _get_colocation(op): + """Get colocation symbol from op, if any.""" + try: + return op.get_attr("_class") + except ValueError: + return None + + class CriticalSection(object): """Critical section. @@ -180,8 +198,8 @@ class CriticalSection(object): The tensors returned from `fn(*args, **kwargs)`. Raises: - ValueError: If `fn` attempts to use this `CriticalSection` in any nested - way. + ValueError: If `fn` attempts to lock this `CriticalSection` in any nested + or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same resources as in `*args`, `**kwargs`, and any additionaly captured @@ -193,69 +211,52 @@ class CriticalSection(object): exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) with ops.name_scope(name, "critical_section_execute", []): - lock = gen_resource_variable_ops.mutex_lock(self._handle) - - with ops.control_dependencies([lock]): - c_known_ops = set() - c_captured_tensors = set() - def add_op_internal(op): - c_known_ops.add(op) - for i in op.inputs: - if i.op not in c_known_ops: - c_captured_tensors.add(i) + # Ensure that mutex locking only happens *after* all args and + # kwargs have been executed. This avoids certain types of deadlocks. + lock = gen_resource_variable_ops.mutex_lock(self._handle) - c = function.HelperContext(add_op_internal) - with c: + if not context.executing_eagerly(): + # NOTE(ebrevdo): This is to ensure we don't pick up spurious + # Operations created by other threads. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + existing_ops = ops.get_default_graph().get_operations() + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + # TODO(ebrevdo): If creating critical sections in a python loop, this + # makes graph creation time quadratic. Revisit if this + # becomes a problem. + created_ops = (set(ops.get_default_graph().get_operations()) + .difference(existing_ops)) + else: + with ops.control_dependencies([lock]): r = fn(*args, **kwargs) - resource_inputs = set([ - x for x in - list(nest.flatten(args)) + nest.flatten(kwargs.values()) + - list(c_captured_tensors) - if tensor_util.is_tensor(x) and x.dtype == dtypes.resource]) - - if self._handle in resource_inputs: - raise ValueError("The function fn attempts to access the " - "CriticalSection in which it would be running. " - "This is illegal and would cause deadlocks. " - "CriticalSection: %s." % self._handle) - if not context.executing_eagerly(): - # Collections and op introspection does not work in eager - # mode. This is generally ok; since eager mode (as of - # writing) executes sequentially anyway. - for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - sg_handle_name = ops.convert_to_tensor(sg.handle).name - self_handle_name = ops.convert_to_tensor(self._handle).name - if sg_handle_name == self_handle_name: - # Other executions in the same critical section are allowed. - continue - if not (exclusive_resource_access or sg.exclusive_resource_access): - # Neither execution requested exclusive access. - continue - resource_intersection = resource_inputs.intersection(sg.resources) - if resource_intersection: - raise ValueError( - "This execution would access resources: %s. Either this " - "lock (CriticalSection: %s) or lock '%s' " - "(CriticalSection: %s) requested exclusive resource access " - "of this resource. Did you mean to call execute with keyword " - "argument exclusive_resource_access=False?" % - (list(resource_intersection), self._handle.name, - sg.op.name, sg.handle.name)) - - def identity(x): # pylint: disable=invalid-name - if isinstance(x, tensor_array_ops.TensorArray): - return x.identity() - elif isinstance(x, ops.Operation): - return control_flow_ops.group(x) - elif context.executing_eagerly() and x is None: - return None - else: - return array_ops.identity(x) - - r_flat = [identity(x) for x in nest.flatten(r)] + self._add_control_dependencies_to_lock(created_ops, lock.op) + + # captured_resources is a list of resources that are directly + # accessed only by ops created during fn(), not by any + # ancestors of those ops in the graph. + captured_resources = set([ + input_ for op in created_ops + for input_ in op.inputs + if input_.dtype == dtypes.resource + ]) + + # NOTE(ebrevdo): The only time self._is_self_handle() is True + # in this call is if one of the recently created ops, within + # the execute(), themselves attempt to access the + # CriticalSection. This will cause a deadlock. + if any(self._is_self_handle(x) for x in captured_resources): + raise ValueError("The function fn attempts to directly access the " + "CriticalSection in which it would be running. " + "This is illegal and would cause deadlocks.") + + self._check_multiple_access_to_resources( + captured_resources, exclusive_resource_access) + + r_flat = [_identity(x) for x in nest.flatten(r)] with ops.control_dependencies(r_flat): # The identity must run on the same machine as self._handle @@ -268,23 +269,105 @@ class CriticalSection(object): # Make sure that if any element of r is accessed, all of # them are executed together. - r = nest.pack_sequence_as( - r, control_flow_ops.tuple(nest.flatten(r))) + r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) with ops.control_dependencies([ensure_lock_exists]): - outputs = nest.map_structure(identity, r) + outputs = nest.map_structure(_identity, r) if not context.executing_eagerly(): signature = _ExecutionSignature( op=lock.op, handle=self._handle, - resources=list(resource_inputs), + resources=list(captured_resources), exclusive_resource_access=exclusive_resource_access) ops.add_to_collections( CRITICAL_SECTION_EXECUTIONS, signature) return outputs + def _add_control_dependencies_to_lock(self, created_ops, lock_op): + """To avoid deadlocks, all args must be executed before lock_op.""" + # Get all arguments (explicit and captured) of all ops created by fn(). + all_args = set([input_.op for op in created_ops for input_ in op.inputs]) + all_args.update( + input_op for op in created_ops for input_op in op.control_inputs) + # Unfortunately, we can't use sets throughout because TF seems to + # create new Operation objects for the same op sometimes; and we + # can't rely on id(op). + + # pylint: disable=protected-access + all_args_dict = dict((op._id, op) for op in all_args) + + # Remove ops created within fn, or that lock_op already has a + # control dependency on. Also remove a possible self-loop. + for op in created_ops: + all_args_dict.pop(op._id, None) + for op in lock_op.control_inputs: + all_args_dict.pop(op._id, None) + for input_ in lock_op.inputs: + all_args_dict.pop(input_.op._id, None) + all_args_dict.pop(lock_op._id, None) + + all_args = all_args_dict.values() + + if not all_args: + # No control dependencies to add; return early. + return + + # This group is important: it ensures that any ops in all_args + # outside the control context of the lock_op (and this fn, which + # runs in the same context) are added to this context before + # being added to the control dependencies of lock_op. + all_args = control_flow_ops.group(*all_args) + + lock_op._add_control_input(all_args) + # pylint: enable=protected-access + + def _is_self_handle(self, x): + """Check if the tensor `x` is the same Mutex as `self._handle`.""" + return (x.op.type == "MutexV2" + # blank shared_name means the op will create a unique one. + and x.op.get_attr("shared_name") + and (x.op.get_attr("shared_name") == + self._handle.op.get_attr("shared_name")) + and (x.op.device == self._handle.op.device + or _get_colocation(x.op) == _get_colocation(self._handle.op))) + + def _check_multiple_access_to_resources( + self, captured_resources, exclusive_resource_access): + """Raise if captured_resources are accessed by another CriticalSection. + + Args: + captured_resources: Set of tensors of type resource. + exclusive_resource_access: Whether this execution requires exclusive + resource access. + + Raises: + ValueError: If any tensors in `captured_resources` are also accessed + by another `CriticalSection`, and at least one of them requires + exclusive resource access. + """ + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if self._is_self_handle(sg.handle): + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + resource_intersection = captured_resources.intersection(sg.resources) + if resource_intersection: + raise ValueError( + "This execution would access resources: %s. Either this " + "lock (CriticalSection: %s) or lock '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource. Did you mean to call execute with keyword " + "argument exclusive_resource_access=False?" % + (list(resource_intersection), self._handle.name, + sg.op.name, sg.handle.name)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # def to_proto(self, export_scope=None): diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index c916592ce1979fe3a79cf28ad4bdac44284cce97..ba660295cb3c97d26da7bf892c78bceee53cf2d4 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging # TODO(ebrevdo): Re-enable once CriticalSection is in core. # from tensorflow.python.training import saver as saver_lib @@ -37,7 +38,7 @@ class CriticalSectionTest(test.TestCase): v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn(a, b): - c = v.read_value() + c = v.value() with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): @@ -140,15 +141,151 @@ class CriticalSectionTest(test.TestCase): ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) def testRecursiveCriticalSectionAccessIsIllegal(self): + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection() + def fn(x): + return cs.execute(lambda y: y + 1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to directly access the CriticalSection in which it " + r"would be running"): + cs.execute(fn, 1.0) + + def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self): + # This one is subtle; and we're being overly cautious here. The + # deadlock we are ensuring we catch is: + # + # to_capture = CS[lambda x: x + 1](1.0) + # deadlocked = CS[lambda x: x + to_capture](1.0) + # + # This would have caused a deadlock because executing `deadlocked` will + # lock the mutex on CS; but then due to dependencies, will attempt + # to compute `to_capture`. This computation requires locking CS, + # but that is not possible now because CS is already locked by + # `deadlocked`. + # + # We check that CriticalSection.execute properly inserts new + # control dependencies to its lock to ensure all captured + # operations are finished before anything runs within the critical section. + cs = critical_section_ops.CriticalSection(shared_name="cs") + fn = array_ops.identity + to_capture = cs.execute(fn, 1.0) + fn_captures = lambda x: x + to_capture + to_capture_too = array_ops.identity(to_capture) + + ex_0 = cs.execute(fn_captures, 1.0) + + with ops.control_dependencies([to_capture]): + # This is OK because to_capture will execute before this next call + ex_1 = cs.execute(fn_captures, 1.0) + + dependency = array_ops.identity(to_capture) + + fn_captures_dependency = lambda x: x + dependency + + ex_2 = cs.execute(fn_captures_dependency, 1.0) + + with ops.control_dependencies([to_capture_too]): + ex_3 = cs.execute(fn_captures_dependency, 1.0) + + # Ensure there's no actual deadlock on to_execute. + self.assertEquals(2.0, self.evaluate(ex_0)) + self.assertEquals(2.0, self.evaluate(ex_1)) + self.assertEquals(2.0, self.evaluate(ex_2)) + self.assertEquals(2.0, self.evaluate(ex_3)) + + def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self): + cs = critical_section_ops.CriticalSection(shared_name="cs") + + def body_implicit_capture(i, j): + # This would have caused a deadlock if not for logic in execute + # that inserts additional control dependencies onto the lock op: + # * Loop body argument j is captured by fn() + # * i is running in parallel to move forward the execution + # * j is not being checked by the predicate function + # * output of cs.execute() is returned as next j. + fn = lambda: j + 1 + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + + def body_implicit_capture_protected(i, j): + # This version is ok because we manually add a control + # dependency on j, which is an argument to the while_loop body + # and captured by fn. + fn = lambda: j + 1 + with ops.control_dependencies([j]): + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture_protected, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + + def body_args_capture(i, j): + # This version is ok because j is an argument to fn and we can + # ensure there's a control dependency on j. + fn = lambda x: x + 1 + return (i + 1, cs.execute(fn, j)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_args_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + + def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self): # This does not work properly in eager mode. Eager users will # just hit a deadlock if they do this. But at least it'll be easier # to debug. cs = critical_section_ops.CriticalSection(shared_name="cs") + cs_same = critical_section_ops.CriticalSection(shared_name="cs") def fn(x): - return cs.execute(lambda x: x+1, x) + return cs_same.execute(lambda x: x+1, x) with self.assertRaisesRegexp( ValueError, - r"attempts to access the CriticalSection in which it would be running"): + r"attempts to directly access the CriticalSection in which it " + r"would be running"): cs.execute(fn, 1.0) def testMultipleCSExecutionsRequestSameResource(self): @@ -179,6 +316,20 @@ class CriticalSectionTest(test.TestCase): ValueError, "requested exclusive resource access"): cs1.execute(lambda: v2 + 1) + def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0, name="v") + # Make sure that the control dependencies on v do not cause issues + # in the lock_op's automatic control dependency adder. + # + # Note, here v must be a resource variable (or something similar), + # otherwise it gets hoisted into the while_loop by the time we add + # control dependencies to the lock_op. + out = control_flow_ops.while_loop( + lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + self.evaluate(v.initializer) + self.assertEqual(10, self.evaluate(out)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 8f62f0ea7b9b561f235b9496ffda97a9f378d530..1921a77c1e96ee3531d1ed0f98e41c27c9d427ac 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Support for sorting tensors. +@@argsort @@sort """ @@ -21,6 +22,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -47,64 +51,141 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): ValueError: If axis is not a constant scalar, or the direction is invalid. """ with framework_ops.name_scope(name, 'sort'): - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. - axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error + return _sort_or_argsort(values, axis, direction, return_argsort=False) + + +def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): + """Returns the indices of a tensor that give its sorted order along an axis. + + For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to + `tf.sort(values)`. For higher dimensions, the output has the same shape as + `values`, but along the given axis, values represent the index of the sorted + element in that slice of the tensor at the given position. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + stable: If True, equal elements in the original tensor will not be + re-ordered in the returned order. Unstable sort is not yet implemented, + but will eventually be the default for performance reasons. If you + require a stable order, pass `stable=True` for forwards compatibility. + name: Optional name for the operation. + + Returns: + An int32 `Tensor` with the same shape as `values`. The indices that would + sort each slice of the given `values` along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + del stable # Unused. + with framework_ops.name_scope(name, 'argsort'): + return _sort_or_argsort(values, axis, direction, return_argsort=True) + + +def _sort_or_argsort(values, axis, direction, return_argsort): + """Internal sort/argsort implementation. + + Args: + values: The input values. + axis: The axis along which to sort. + direction: 'ASCENDING' or 'DESCENDING'. + return_argsort: Whether to return the argsort result. + + Returns: + Either the sorted values, or the indices of the sorted values in the + original tensor. See the `sort` and `argsort` docstrings. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error - values = framework_ops.convert_to_tensor(values, name='values') + values = framework_ops.convert_to_tensor(values, name='values') - return _SORT_IMPL[direction](values, axis_static) + return _SORT_IMPL[direction](values, axis_static, return_argsort) -def _descending_sort(values, axis): +def _descending_sort(values, axis, return_argsort=False): """Sorts values in reverse using `top_k`. Args: values: Tensor of numeric values. axis: Index of the axis which values should be sorted along. + return_argsort: If False, return the sorted values. If True, return the + indices that would sort the values. Returns: The sorted values. """ k = array_ops.shape(values)[axis] rank = array_ops.rank(values) + static_rank = values.shape.ndims # Fast path: sorting the last axis. if axis == -1 or axis + 1 == values.get_shape().ndims: - return nn_ops.top_k(values, k)[0] - - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Make axis a Tensor with the real axis index if needed. - axis += rank - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - values, unused_indices = nn_ops.top_k(top_k_input, k) - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return array_ops.transpose(values, transposition) - - -def _ascending_sort(values, axis): + top_k_input = values + transposition = None + else: + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Calculate the actual axis index if counting from the end. Use the static + # rank if available, or else make the axis back into a tensor. + axis += static_rank or rank + if static_rank is not None: + # Prefer to calculate the transposition array in NumPy and make it a + # constant. + transposition = constant_op.constant( + np.r_[ + # Axes up to axis are unchanged. + np.arange(axis), + # Swap axis and rank - 1. + [static_rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + np.arange(axis + 1, static_rank - 1), + # Swap axis and rank - 1. + [axis]], + name='transposition') + else: + # Generate the transposition array from the tensors. + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + + values, indices = nn_ops.top_k(top_k_input, k) + return_value = indices if return_argsort else values + if transposition is not None: + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return_value = array_ops.transpose(return_value, transposition) + return return_value + + +def _ascending_sort(values, axis, return_argsort=False): # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis) - return -values_or_indices + values_or_indices = _descending_sort(-values, axis, return_argsort) + # If not argsort, negate the values again. + return values_or_indices if return_argsort else -values_or_indices _SORT_IMPL = { diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index d08ae502f10d98ee14d8bea2f76b18bedb935cea..a8fb94b245dccc8c7cf0e94cef9b436f881fe408 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -24,6 +24,8 @@ from tensorflow.contrib.framework.python.ops import sort_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -90,6 +92,38 @@ class SortTest(test.TestCase): axis=0, direction='DESCENDING').eval()) + def testSort_staticallyKnownRank_constantTransposition(self): + # The transposition array should be a constant if the rank of "values" is + # statically known. + tensor = random_ops.random_uniform( + # Rank is statically known to be 5, but the dimension lengths are not + # known. + random_ops.random_uniform( + shape=(5,), minval=0, maxval=10, dtype=dtypes.int32)) + sort_ops.sort(tensor, axis=1) + transposition = ( + ops.get_default_graph().get_tensor_by_name('sort/transposition:0')) + self.assertFalse(tensor_util.constant_value(transposition) is None) + self.assertAllEqual( + # Swaps "1" and "4" to put "1" at the end. + tensor_util.constant_value(transposition), + [0, 4, 2, 3, 1]) + + def testArgsort_1d(self): + arr = np.random.random(42) + with self.test_session(): + self.assertAllEqual( + np.sort(arr), + array_ops.gather(arr, sort_ops.argsort(arr)).eval()) + + def testArgsort(self): + arr = np.random.random((5, 6, 7, 8)) + for axis in range(4): + with self.test_session(): + self.assertAllEqual( + np.argsort(arr, axis=axis), + sort_ops.argsort(arr, axis=axis).eval()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index ce37672895b37275770d2f5410f662e9acf1de9d..0eb6889db1fae1c74aeb4392441b308392b091a5 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -157,15 +157,3 @@ cuda_py_test( "requires_cudnn6", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index ff6f3b744190c9a7c74fb88878e5f13412251e79..461066bbb493932b342cee8f8842e899a2d84fff 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -545,15 +545,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 082c42eba180917e732bb7890129dfa94bf00fec..e3fc6bf0f034051fc33ff5966e2f4ea85aa538db 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -88,8 +88,8 @@ class GANEstimator(estimator.Estimator): discriminator_fn=discriminator_fn, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) # Train estimator. gan_estimator.train(train_input_fn, steps) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 4811edcbcfa63e99210b3c2f416b71bb83915869..47e51415fd9e7daa360ca06a11078f6edcf63b5b 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader - __all__ = [ 'get_graph_def_from_disk', 'get_graph_def_from_resource', @@ -62,10 +62,11 @@ __all__ = [ 'frechet_inception_distance', 'frechet_classifier_distance', 'frechet_classifier_distance_from_activations', + 'mean_only_frechet_classifier_distance_from_activations', + 'diagonal_only_frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] - INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' INCEPTION_INPUT = 'Mul:0' @@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): images = ops.convert_to_tensor(images) images.shape.with_rank(4) - images.shape.assert_is_compatible_with( - [None, image_size, image_size, None]) + images.shape.assert_is_compatible_with([None, image_size, image_size, None]) return images @@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -def preprocess_image( - images, height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): +def preprocess_image(images, + height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, + scope=None): """Prepare a batch of images for evaluation. This is the preprocessing portion of the graph from @@ -272,8 +273,11 @@ def run_inception(images, return activations -def run_image_classifier(tensor, graph_def, input_tensor, - output_tensor, scope='RunClassifier'): +def run_image_classifier(tensor, + graph_def, + input_tensor, + output_tensor, + scope='RunClassifier'): """Runs a network from a frozen graph. Args: @@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v): sqrt_sigma = _symmetric_matrix_square_root(sigma) # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul( - sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, + math_ops.matmul(sigma_v, sqrt_sigma)) return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) @@ -452,7 +456,7 @@ def frechet_classifier_distance(real_images, Given two Gaussian distribution with means m and m_w and covariance matrices C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the @@ -511,10 +515,142 @@ def frechet_classifier_distance(real_images, return frechet_classifier_distance_from_activations(real_a, gen_a) -def frechet_classifier_distance_from_activations( +def mean_only_frechet_classifier_distance_from_activations( real_activations, generated_activations): """Classifier distance for evaluating a generative model from activations. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + In this variant, we only compute the difference between the means of the + fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet + still retains much of the same information as FID. + + Args: + real_activations: 2D array of activations of real images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + generated_activations: 2D array of activations of generated images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + + Returns: + The mean-only Frechet Inception distance. A floating-point scalar of the + same type as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute means of activations. + m = math_ops.reduce_mean(real_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + mofid = mean + if activations_dtype != dtypes.float64: + mofid = math_ops.cast(mofid, activations_dtype) + + return mofid + + +def diagonal_only_frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. In this variant, we compute diagonal-only covariance matrices. + As a result, instead of computing an expensive matrix square root, we can do + something much simpler, and has O(n) vs O(n^2) space complexity. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The diagonal-only Frechet Inception distance. A floating-point scalar of + the same type as the output of the activations. + + Raises: + ValueError: If the shape of the variance and mean vectors are not equal. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute mean and covariance matrices of activations. + m, var = nn_impl.moments(real_activations, axes=[0]) + m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) + + actual_shape = var.get_shape() + expected_shape = m.get_shape() + + if actual_shape != expected_shape: + raise ValueError('shape: {} must match expected shape: {}'.format( + actual_shape, expected_shape)) + + # Compute the two components of FID. + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.reduce_sum( + (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + dofid = trace + mean + if activations_dtype != dtypes.float64: + dofid = math_ops.cast(dofid, activations_dtype) + + return dofid + + +def frechet_classifier_distance_from_activations(real_activations, + generated_activations): + """Classifier distance for evaluating a generative model. + This methods computes the Frechet classifier distance from activations of real images and generated images. This can be used independently of the frechet_classifier_distance() method, especially in the case of using large @@ -525,13 +661,20 @@ def frechet_classifier_distance_from_activations( Given two Gaussian distribution with means m and m_w and covariance matrices C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_activations: 2D Tensor containing activations of real data. Shape is [batch_size, activation_size]. @@ -553,36 +696,38 @@ def frechet_classifier_distance_from_activations( # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) - m_v = math_ops.reduce_mean(generated_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / ( + num_examples - 1) - gen_centered = generated_activations - m_v - sigma_v = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_w + sigma_w = math_ops.matmul( + gen_centered, gen_centered, transpose_a=True) / ( + num_examples - 1) - # Find the Tr(sqrt(sigma sigma_v)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) + # Find the Tr(sqrt(sigma sigma_w)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) # Compute the two components of FID. # First the covariance component. # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) return fid - frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 61dc8646ddc10605561ae6b19e90f4739c346608..663e49bdca3cb2dd9257da326488c877fcc4256d 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -50,6 +50,26 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) +def _expected_mean_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + mean = np.square(m - m_v).sum() + mofid = mean + return mofid + + +def _expected_diagonal_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + var = np.var(real_imgs, axis=0) + var_v = np.var(gen_imgs, axis=0) + sqcc = np.sqrt(var * var_v) + mean = (np.square(m - m_v)).sum() + trace = (var + var_v - 2 * sqcc).sum() + dofid = mean + trace + return dofid + + def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) @@ -285,6 +305,46 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) + def test_mean_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_mofid = sess.run(mofid_op) + + expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_mofid, actual_mofid, 0.0001) + + def test_diagonal_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_dofid = sess.run(dofid_op) + + expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_dofid, actual_dofid, 0.0001) + def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 0d1afad72da8a8e087239868e25ddebe23490d1e..508f487722fba89cc8391a340f73673a526e86c4 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -31,6 +31,7 @@ __all__ = [ 'add_image_comparison_summaries', 'add_gan_model_summaries', 'add_regularization_loss_summaries', + 'add_cyclegan_image_summaries', ] @@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): ValueError: If real and generated data aren't images. """ if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_summaries'): - add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_summaries'): - add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) - return - + raise ValueError( + '`add_gan_model_image_summaries` does not take CycleGANModels. Please ' + 'use `add_cyclegan_image_summaries` instead.') _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): add_gan_model_summaries(gan_model) +def add_cyclegan_image_summaries(cyclegan_model): + """Adds image summaries for CycleGAN. + + There are two summaries, one for each generator. The first image is the + generator input, the second is the generator output, and the third is G(F(x)). + + Args: + cyclegan_model: A CycleGANModel tuple. + + Raises: + ValueError: If `cyclegan_model` isn't a CycleGANModel. + ValueError: If generated data, generator inputs, and reconstructions aren't + images. + ValueError: If the generator input, generated data, and reconstructions + aren't all the same size. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was ' + '%s' % type(cyclegan_model)) + + _assert_is_image(cyclegan_model.model_x2y.generator_inputs) + _assert_is_image(cyclegan_model.model_x2y.generated_data) + _assert_is_image(cyclegan_model.reconstructed_x) + _assert_is_image(cyclegan_model.model_y2x.generator_inputs) + _assert_is_image(cyclegan_model.model_y2x.generated_data) + _assert_is_image(cyclegan_model.reconstructed_y) + + def _add_comparison_summary(gan_model, reconstructions): + image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) + + array_ops.unstack(gan_model.generated_data[:1]) + + array_ops.unstack(reconstructions[:1])) + summary.image( + 'image_comparison', eval_utils.image_reshaper( + image_list, num_cols=len(image_list)), max_outputs=1) + + with ops.name_scope('x2y_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_x2y, cyclegan_model.reconstructed_x) + with ops.name_scope('y2x_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_y2x, cyclegan_model.reconstructed_y) + + def add_image_comparison_summaries(gan_model, num_comparisons=2, display_diffs=False): """Adds image summaries to compare triplets of images. @@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ - if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_y2x, **saved_params) - return - _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 45eb108586bed07434ac29595164745eac6054c1..33d51bfc218ab93fb52439b1eefed98a4568c4a1 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -65,15 +65,14 @@ def get_cyclegan_model(): return namedtuples.CycleGANModel( model_x2y=model_x2y, model_y2x=model_y2x, - reconstructed_x=array_ops.zeros([3, 30, 35, 6]), - reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + reconstructed_x=array_ops.zeros([4, 32, 32, 3]), + reconstructed_y=array_ops.zeros([4, 32, 32, 3])) class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops, - model_summaries): + def _test_add_gan_model_image_summaries_impl( + self, get_model_fn, expected_num_summary_ops, model_summaries): summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, model_summaries=model_summaries) @@ -89,8 +88,9 @@ class SummariesTest(test.TestCase): def test_add_gan_model_image_summaries_no_model(self): self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) + def test_cyclegan_image_summaries_dont_work(self): + with self.assertRaises(ValueError): + summaries.add_gan_model_image_summaries(get_cyclegan_model()) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): @@ -137,7 +137,11 @@ class SummariesTest(test.TestCase): self._test_add_image_comparison_summaries_impl(get_gan_model, 1) def test_add_image_comparison_summaries_for_cyclegan(self): - self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + summaries.add_cyclegan_image_summaries(get_cyclegan_model()) + + self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + with self.test_session(use_gpu=True): + summary.merge_all().eval() if __name__ == '__main__': diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 39588b7219ebac1cc4855532be3fcc38e6381134..1ba3a641671c7f2a411a0c5f99228ca16eee1080 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -306,6 +306,7 @@ def wasserstein_gradient_penalty( discriminator_scope, epsilon=1e-10, target=1.0, + one_sided=False, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -327,6 +328,8 @@ def wasserstein_gradient_penalty( computing the gradient norm. target: Optional Python number or `Tensor` indicating the target value of gradient norm. Defaults to 1.0. + one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 + is used. Defaults to `False`. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -377,10 +380,13 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes / target - 1.0) + penalties = slopes / target - 1.0 + if one_sided: + penalties = math_ops.maximum(0., penalties) + penalties_squared = math_ops.square(penalties) penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + penalties_squared, weights, scope=scope, + loss_collection=loss_collection, reduction=reduction) if add_summaries: summary.scalar('gradient_penalty_loss', penalty) @@ -665,7 +671,7 @@ def least_squares_discriminator_loss( loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): - """Least squares generator loss. + """Least squares discriminator loss. This loss comes from `Least Squares Generative Adversarial Networks` (https://arxiv.org/abs/1611.04076). diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index dbaa624ae9d6a5a5949db692e52c0c1deb18b8df..2889e937436d2faa66b5693c19046e122cbaf652 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,28 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_using_one_sided_mode(self): + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + one_sided=True) + self.assertEqual(generated_data.dtype, loss.dtype) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run(loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): """Test loss value with non default gradient norm target.""" generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 776eb11ecb1624544d24611d8fe6ca19768b8313..73acd05b60a5fb02601423fd9234a56a34f75276 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -461,6 +461,7 @@ def gan_loss( gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, @@ -485,6 +486,8 @@ def gan_loss( gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python number or `Tensor` indicating the target value of gradient norm. See the CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. + gradient_penalty_one_sided: If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -546,6 +549,7 @@ def gan_loss( model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index f9bdaa74c948ecee11d5cfd89f06087924f8dace..3ebbe55d059e5e72607bc4efdbf95a6c96d99f11 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -359,10 +359,12 @@ class GANLossTest(test.TestCase): self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn): + def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0) + loss_gp = train.gan_loss(model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) # Check values. @@ -394,6 +396,25 @@ class GANLossTest(test.TestCase): def test_grad_penalty_callable_acgan(self): self._test_grad_penalty_helper(create_callable_acgan_model) + def test_grad_penalty_one_sided_gan(self): + self._test_grad_penalty_helper(create_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_gan(self): + self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_infogan(self): + self._test_grad_penalty_helper(create_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_infogan(self): + self._test_grad_penalty_helper( + create_callable_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_acgan(self): + self._test_grad_penalty_helper(create_acgan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_acgan(self): + self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) + # Test mutual information penalty option. def _test_mutual_info_penalty_helper(self, create_gan_model_fn): train.gan_loss(create_gan_model_fn(), diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 707ae25d485c64f15694ee0e357f32b619d3cd33..e534fdc17749974ebe713c2730682bea6d7a85e4 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -9,18 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - filegroup( name = "c_srcs", data = glob([ diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 967ad2fc090906e93f22c777816eede37f9a1b04..1711100e3a857dba0d15c5b4f6c96cddc568e800 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -39,18 +39,6 @@ py_library( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "match", srcs = ["tests/match.py"], diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ca00394388f67e2ed9508684a47b23c3ee9e79e8..2603de640735a612cbd883cc6227fe3cd9f11fca 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -84,9 +85,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op): + def my_transform_op_handler(info, op, new_inputs): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -201,15 +202,56 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEquals(original_mul1_grad._original_op.name, u"mul1") - self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEquals(res.name, g.name) + self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + def test_graph_while_loop(self): + graph = ops.Graph() + with graph.as_default(): + max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) + index_start = constant_op.constant(1) + sum_start = constant_op.constant(0) + _, result = control_flow_ops.while_loop( + cond=lambda i, unused_s: i <= max_index, + body=lambda i, s: (i + 1, s + i), + loop_vars=[index_start, sum_start]) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_max_index = copy_info.transformed(max_index) + with copied_graph.as_default(): + with session.Session() as sess: + n = 10 + sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) + self.assertEqual(sum_val, 55) + + def test_graph_cond(self): + graph = ops.Graph() + with graph.as_default(): + choice = array_ops.placeholder(shape=(), dtype=dtypes.bool) + result = control_flow_ops.cond( + choice, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_choice = copy_info.transformed(choice) + with copied_graph.as_default(): + with session.Session() as sess: + res = sess.run(copied_result, feed_dict={copied_choice: True}) + self.assertEqual(res, 1) + res = sess.run(copied_result, feed_dict={copied_choice: False}) + self.assertEqual(res, 2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 14ac5296657d48c7f9e94d220c9e7e28af4d4353..d8a48387a745e7d88cc6a74c96cb21a2ba1cfa1f 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,20 +129,26 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. + new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ + # The `new_inputs` was added to this function. For compatibility reason, + # let's raise an error if `new_inputs` is a boolean. + if isinstance(new_inputs, bool): + raise TypeError("the `new_inputs` argument must be an iterable.") + # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op._node_def) + node_def_ = deepcopy(op.node_def) # Transform name: name_ = info.new_name(op.name) @@ -155,10 +161,10 @@ def copy_op_handler(info, op, copy_shape=True): # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op._op_def) + op_def_ = deepcopy(op.op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -170,6 +176,7 @@ def copy_op_handler(info, op, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. + # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op @@ -328,6 +335,14 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler + # The graph is transformed op by op, in the same order the original ops + # were created. However, this is sometimes not possible due to cycles + # (i.e. while loops). So when the transformer creates a new op whose + # inputs do not exist yet, temporary placeholders are created and stored + # in this `tmp_cyclic_ts` container. During a second pass, + # those temporary tensors are replaced by the proper transformed tensors + # (see the function `_finalize_cycles`). + self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -428,10 +443,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) - info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._connect_ops(info) + self._finalize_cycles(info) + self._connect_control_inputs(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -440,10 +455,10 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - for op in info.sgv.ops: - logging.debug("Copying op: %s", op.name) - # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op) + sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access + for op in sorted_ops: + new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] + op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -456,27 +471,36 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _connect_ops(self, info): + def _finalize_cycles(self, info): + """Reconnects the cyclic tensors.""" + for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: + if t not in info.transformed_ts: + raise ValueError("The tensor {} should be transformed by now.".format( + t.name)) + if consumer_op not in info.transformed_ops: + raise ValueError("The op {} should be transformed by now.".format( + consumer_op.name)) + t_ = info.transformed_ts[t] + consumer_op_ = info.transformed_ops[consumer_op] + t_index_ = list(consumer_op_.inputs).index(tmp_t_) + consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + + def _connect_control_inputs(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Finalizing op: %s", op.name) + logging.debug("Connecting control inputs of op: %s", op.name) op_ = info.transformed_ops[op] - # pylint: disable=protected-access - if op_.inputs: - raise ValueError("The newly transformed op should not have " - "any inputs yet: {}".format(op_.name)) - inputs_ = [self._transformed_t(info, t) for t in op.inputs] - for t in inputs_: - op_._add_input(t) - # Finalize original op. + # TODO(fkp): Stop worrying about _original_op and remove this code? + # pylint: disable=protected-access if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) + original_op = self.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op + # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -525,19 +549,38 @@ class Transformer(object): return sgv_.remap(input_map_, output_map_) - def _transformed_t(self, info, t): + def _transformed_t(self, info, t, consumer_op): """Return tre transformed tensor of `t`.""" - if t not in info.transformed_ts: - # If op is not in the subgraph. - if t in info.sgv_inputs_set: - # t is an input of the subgraph. - return self.transform_external_input_handler(info, t) + if t in info.transformed_ts: + # If op is in the subgraph, just return its transformed counterpart. + return info.transformed_ts[t] + + if t in info.sgv_inputs_set: + # `t` is an input of the subgraph. + return self.transform_external_input_handler(info, t) + elif t.op in info.ops: + # `t` is an internal tensor but is not transformed yet because it + # belongs to a graph cycle. + logging.debug("Cyclic tensor: t.name = %s", t.name) + # Try to find an existing tensor we can use for now, + # otherwise create one. We'll rewire this later. + if consumer_op.type == "Merge": + first_input = consumer_op.inputs[0] + tmp_t_ = self._transformed_t(info, first_input, consumer_op) + elif t.op.type == "Enter": + enter_input = t.op.inputs[0] + tmp_t_ = self._transformed_t(info, enter_input, consumer_op) else: - # t is a hidden input of the subgraph. - return self.transform_external_hidden_input_handler(info, t) + with info.graph_.as_default(): + tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, + prefix="geph_tmp") + logging.debug("Created temporary placeholder: %s.", tmp_t_.name) + # Register as temporary and return. + info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) + return tmp_t_ else: - # If op is in the subgraph, just return its transformed. - return info.transformed_ts[t] + # `t` is a hidden input of the subgraph. + return self.transform_external_hidden_input_handler(info, t) def copy(sgv, dst_graph=None, dst_scope="", src_scope="", @@ -624,6 +667,40 @@ def copy_with_input_replacements(sgv, replacement_ts, sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) +def _add_control_flow_ops(ops, control_ios): + """Complete `ops` so that the tranformed graph is valid. + + Partially copying a graph can lead to a malformed graph. For instance, + copying half of a while construct is likely to result in an invalid graph. + This function attempts to add missing ops so that the transformation result + in a valid graph. + + Args: + ops: list of ops (modifed in-place). + control_ios: object created by a call to `util.ControlOutputs`. + """ + # Find while contexts. + control_flow_contexts = set() + for op in ops: + cfc = op._control_flow_context # pylint: disable=protected-access + if cfc: + control_flow_contexts.add(cfc) + # Find new ops. + new_ops = [] + for cfc in control_flow_contexts: + if cfc.IsWhileContext(): + new_ops += select.get_walks_intersection_ops( + [enter_t.op for enter_t in cfc.loop_enters], + [exit_t.op for exit_t in cfc.loop_exits], + control_ios=control_ios) + # Add new ops. + new_ops_set = set(new_ops) + ops_set = frozenset(ops) + for op in new_ops_set: + if op not in ops_set: + ops.append(op) + + def graph_replace(target_ts, replacement_ts, dst_scope="", src_scope="", reuse_dst_scope=False): """Create a new graph which compute the targets from the replaced Tensors. @@ -657,8 +734,13 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") + + # Complete ops to avoid malformed control flow. + # TODO(fkp): Consider moving this function deeper (in the transformer?). + _add_control_flow_ops(ops, control_ios) + # Create a copy of the relevant subgraph - _, info = copy_with_input_replacements( + unused_sgv_, info = copy_with_input_replacements( ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) # Return the transformed targets but keep the original if the transformed # counterpart cannot be found diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef..584f4509ccc0aab30edc2be3bad7a9cb938d6e6a 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -38,6 +38,11 @@ __all__ = [ ] +# The graph editor sometimes need to create placeholders, they are named +# "geph_*". "geph" stands for Graph-Editor PlaceHolder. +_DEFAULT_PLACEHOLDER_PREFIX = "geph" + + def concatenate_unique(la, lb): """Add all the elements of `lb` to `la` if they are not there already. @@ -405,7 +410,7 @@ def scope_basename(scope): return scope[slash + 1:] -def placeholder_name(t=None, scope=None): +def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create placeholder name for the graph editor. Args: @@ -413,6 +418,7 @@ def placeholder_name(t=None, scope=None): on scope: absolute scope with which to prefix the placeholder's name. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A new placeholder name prefixed by "geph". Note that "geph" stands for Graph Editor PlaceHolder. This convention allows to quickly identify the @@ -430,19 +436,20 @@ def placeholder_name(t=None, scope=None): if scope is None: scope = op_dirname - if op_basename.startswith("geph__"): + if op_basename.startswith("{}__".format(prefix)): ph_name = op_basename else: - ph_name = "geph__{}_{}".format(op_basename, t.value_index) + ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) return scope + ph_name else: if scope is None: scope = "" - return scope + "geph" + return "{}{}".format(scope, prefix) -def make_placeholder_from_tensor(t, scope=None): +def make_placeholder_from_tensor(t, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a `tf.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -452,17 +459,19 @@ def make_placeholder_from_tensor(t, scope=None): (see function placeholder_name). scope: absolute scope within which to create the placeholder. None means that the scope of `t` is preserved. `""` means the root scope. + prefix: placeholder name prefix. Returns: A newly created `tf.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ return tf_array_ops.placeholder( - dtype=t.dtype, shape=t.get_shape(), name=placeholder_name( - t, scope=scope)) + dtype=t.dtype, shape=t.get_shape(), + name=placeholder_name(t, scope=scope, prefix=prefix)) -def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): +def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a tf.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -474,11 +483,13 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): shape: the tensor shape (optional). scope: absolute scope within which to create the placeholder. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A newly created tf.placeholder. """ return tf_array_ops.placeholder( - dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + dtype=dtype, shape=shape, + name=placeholder_name(scope=scope, prefix=prefix)) _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index d601a1ec6f7a219bcd461d819ab2dfc64135a3ae..d0b44640667010b58c017d933d50ae5f87e8b275 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -41,15 +41,3 @@ cuda_py_tests( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index 1b528d7afc1112f5dc0667ae299ade02bc8fd04b..d65b2d6026dd89959aa62b57e07b073eef84572c 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -23,14 +23,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD index 324035100df366b80f57af9052c4bd935655b248..e39c60b252a1b49a68d51302fff47734869dddfe 100644 --- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD +++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD @@ -13,18 +13,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//visibility:public"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_cc_binary( name = "clock_cycle_profiling", testonly = 1, diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 909dc396a33b6fef1b2d51c3f52fab7782fc8ea5..0081fb61770075a2c36e92f65e01126f657edeb4 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -10,17 +10,6 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - tf_cc_binary( name = "hvx_ops_support_checker", testonly = 1, diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 3ff02e085ee63fabf42b3cc4389f4605455f3800..da450480b30b548484e69c61c85667d6dd390417 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -78,7 +78,10 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":dense_image_warp_py", ":image_ops", + ":interpolate_spline_py", + ":sparse_image_warp_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", @@ -194,6 +197,117 @@ cuda_py_test( ], ) +py_library( + name = "dense_image_warp_py", + srcs = [ + "python/ops/dense_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_library( + name = "interpolate_spline_py", + srcs = [ + "python/ops/interpolate_spline.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_library( + name = "sparse_image_warp_py", + srcs = [ + "python/ops/sparse_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dense_image_warp_py", + ":interpolate_spline_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "sparse_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/sparse_image_warp_test.py"], + additional_deps = [ + ":sparse_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], + data = [":sparse_image_warp_test_data"], + tags = ["no_pip"], +) + +filegroup( + name = "sparse_image_warp_test_data", + srcs = glob(["python/kernel_tests/test_data/*.png"]), +) + +cuda_py_test( + name = "dense_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/dense_image_warp_test.py"], + additional_deps = [ + ":dense_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + +cuda_py_test( + name = "interpolate_spline_test", + size = "medium", + srcs = ["python/kernel_tests/interpolate_spline_test.py"], + additional_deps = [ + ":interpolate_spline_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + tf_py_test( name = "segmentation_test", size = "medium", @@ -270,15 +384,3 @@ cuda_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index cc8ed117ba2edcc7a53e609381166f17a2fbb45e..e982030bc8959309e72d0f4e02b9755c48535a10 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -30,6 +30,9 @@ projective transforms (including rotation) are supported. @@transform @@translate @@translations_to_projective_transforms +@@dense_image_warp +@@interpolate_spline +@@sparse_image_warp ## Image Segmentation `Ops` @@ -47,6 +50,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.python.ops.dense_image_warp import dense_image_warp + from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq @@ -57,7 +62,9 @@ from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms +from tensorflow.contrib.image.python.ops.interpolate_spline import interpolate_spline from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms +from tensorflow.contrib.image.python.ops.sparse_image_warp import sparse_image_warp from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a58b6a247ed6ae252db25a12f1e47c08c9a5c147 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -0,0 +1,267 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import adam + + +class DenseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def test_interpolate_small_grid_ij(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_xy(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear( + grid, query_points, indexing='xy') + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_batched(self): + grid = constant_op.constant( + [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1]) + query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]], + [[0.5, 0.], [1., 0.], [1., 1.]]]) + expected_results = np.reshape( + np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def get_image_and_flow_placeholders(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + flow_shape = [batch_size, height, width, 2] + + tf_type = { + 'float16': dtypes.half, + 'float32': dtypes.float32, + 'float64': dtypes.float64 + } + + image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape) + + flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape) + return image, flows + + def get_random_image_and_flows(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + image = np.random.normal(size=image_shape) + flow_shape = [batch_size, height, width, 2] + flows = np.random.normal(size=flow_shape) * 3 + return image.astype(image_type), flows.astype(flow_type) + + def assert_correct_interpolation_value(self, + image, + flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=False): + """Assert that the tf interpolation matches hand-computed value.""" + + height = image.shape[1] + width = image.shape[2] + displacement = flows[batch_index, y_index, x_index, :] + float_y = y_index - displacement[0] + float_x = x_index - displacement[1] + floor_y = max(min(height - 2, math.floor(float_y)), 0) + floor_x = max(min(width - 2, math.floor(float_x)), 0) + ceil_y = floor_y + 1 + ceil_x = floor_x + 1 + + alpha_y = min(max(0.0, float_y - floor_y), 1.0) + alpha_x = min(max(0.0, float_x - floor_x), 1.0) + + floor_y = int(floor_y) + floor_x = int(floor_x) + ceil_y = int(ceil_y) + ceil_x = int(ceil_x) + + top_left = image[batch_index, floor_y, floor_x, :] + top_right = image[batch_index, floor_y, ceil_x, :] + bottom_left = image[batch_index, ceil_y, floor_x, :] + bottom_right = image[batch_index, ceil_y, ceil_x, :] + + interp_top = alpha_x * (top_right - top_left) + top_left + interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left + interp = alpha_y * (interp_bottom - interp_top) + interp_top + atol = 1e-6 + rtol = 1e-6 + if low_precision: + atol = 1e-2 + rtol = 1e-3 + self.assertAllClose( + interp, + pred_interpolation[batch_index, y_index, x_index, :], + atol=atol, + rtol=rtol) + + def check_zero_flow_correctness(self, shape, image_type, flow_type): + """Assert using zero flows doesn't change the input image.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + rand_flows *= 0 + + predicted_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + self.assertAllClose(rand_image, predicted_interpolation) + + def test_zero_flows(self): + """Apply check_zero_flow_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] + for shape in shapes_to_try: + self.check_zero_flow_correctness( + shape, image_type='float32', flow_type='float32') + + def check_interpolation_correctness(self, + shape, + image_type, + flow_type, + num_probes=5): + """Interpolate, and then assert correctness for a few query locations.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + low_precision = image_type == 'float16' or flow_type == 'float16' + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + + pred_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + + for _ in range(num_probes): + batch_index = np.random.randint(0, shape[0]) + y_index = np.random.randint(0, shape[1]) + x_index = np.random.randint(0, shape[2]) + + self.assert_correct_interpolation_value( + rand_image, + rand_flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=low_precision) + + def test_interpolation(self): + """Apply check_interpolation_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]] + for im_type in ['float32', 'float64', 'float16']: + for flow_type in ['float32', 'float64', 'float16']: + for shape in shapes_to_try: + self.check_interpolation_correctness(shape, im_type, flow_type) + + def test_gradients_exist(self): + """Check that backprop can run. + + The correctness of the gradients is assumed, since the forward propagation + is tested to be correct and we only use built-in tf ops. + However, we perform a simple test to make sure that backprop can actually + run. We treat the flows as a tf.Variable and optimize them to minimize + the difference between the interpolated image and the input image. + """ + + batch_size, height, width, numchannels = [4, 5, 6, 7] + image_shape = [batch_size, height, width, numchannels] + image = random_ops.random_normal(image_shape) + flow_shape = [batch_size, height, width, 2] + init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25) + flows = variables.Variable(init_flows) + + interp = dense_image_warp.dense_image_warp(image, flows) + loss = math_ops.reduce_mean(math_ops.square(interp - image)) + + optimizer = adam.AdamOptimizer(1.0) + grad = gradients.gradients(loss, [flows]) + opt_func = optimizer.apply_gradients(zip(grad, [flows])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(10): + sess.run(opt_func) + + def test_size_exception(self): + """Make sure it throws an exception for images that are too small.""" + + shape = [1, 2, 1, 1] + msg = 'Should have raised an exception for invalid image size' + with self.assertRaises(ValueError, msg=msg): + self.check_interpolation_correctness(shape, 'float32', 'float32') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1939caaa2d8586413cf9ecba6ce73cf64910d6fc --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for interpolate_spline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import interpolate as sc_interpolate + +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util + +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import momentum + + +class _InterpolationProblem(object): + """Abstract class for interpolation problem descriptions.""" + + def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'): + """Make data for an interpolation problem where all x vectors are n-d. + + Args: + optimizable: If True, then make train_points a tf.Variable. + extrapolate: If False, then clamp the query_points values to be within + the max and min of train_points. + dtype: The data type to use. + + Returns: + query_points, query_values, train_points, train_values: training and + test tensors for interpolation problem + """ + + # The values generated here depend on a seed of 0. + np.random.seed(0) + + batch_size = 1 + num_training_points = 10 + num_query_points = 4 + + init_points = np.random.uniform( + size=[batch_size, num_training_points, self.DATA_DIM]) + + init_points = init_points.astype(dtype) + train_points = ( + variables.Variable(init_points) + if optimizable else constant_op.constant(init_points)) + train_values = self.tf_function(train_points) + + query_points_np = np.random.uniform( + size=[batch_size, num_query_points, self.DATA_DIM]) + query_points_np = query_points_np.astype(dtype) + if not extrapolate: + query_points_np = np.clip(query_points_np, np.min(init_points), + np.max(init_points)) + + query_points = constant_op.constant(query_points_np) + query_values = self.np_function(query_points_np) + + return query_points, query_values, train_points, train_values + + +class _QuadraticPlusSinProblem1D(_InterpolationProblem): + """1D interpolation problem used for regression testing.""" + DATA_DIM = 1 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], + (1.0, + 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], + (2.0, + 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], + (2.0, + 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741], + (3.0, + 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122], + (3.0, + 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857], + (4.0, + 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256], + (4.0, 0.01): [ + -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 + ] + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_mean( + math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10), + 2, + keepdims=True) + + +class _QuadraticPlusSinProblemND(_InterpolationProblem): + """3D interpolation problem used for regression testing.""" + + DATA_DIM = 3 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], + (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], + (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], + (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312], + (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491], + (3.0, + 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866], + (4.0, + 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833], + (4.0, + 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382], + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_sum( + math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15), + 2, + keepdims=True) + + +class InterpolateSplineTest(test_util.TensorFlowTestCase): + + def test_1d_linear_interpolation(self): + """For 1d linear interpolation, we can compare directly to scipy.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, train_values) = tp.get_problem( + extrapolate=False, dtype='float64') + interpolation_order = 1 + + with ops.name_scope('interpolator'): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order) + with self.test_session() as sess: + fetches = [query_points, train_points, train_values, interpolator] + query_points_, train_points_, train_values_, interp_ = sess.run(fetches) + + # Just look at the first element of the minibatch. + # Also, trim the final singleton dimension. + interp_ = interp_[0, :, 0] + query_points_ = query_points_[0, :, 0] + train_points_ = train_points_[0, :, 0] + train_values_ = train_values_[0, :, 0] + + # Compute scipy interpolation. + scipy_interp_function = sc_interpolate.interp1d( + train_points_, train_values_, kind='linear') + + scipy_interpolation = scipy_interp_function(query_points_) + scipy_interpolation_on_train = scipy_interp_function(train_points_) + + # Even with float64 precision, the interpolants disagree with scipy a + # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. + tol = 1e-3 + + self.assertAllClose( + train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol) + self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol) + + def test_1d_interpolation(self): + """Regression test for interpolation with 1-D points.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_nd_linear_interpolation(self): + """Regression test for interpolation with N-D points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_interpolation_gradient(self): + """Make sure that backprop can run. Correctness of gradients is assumed. + + Here, we create a use a small 'training' set and a more densely-sampled + set of query points, for which we know the true value in advance. The goal + is to choose x locations for the training data such that interpolating using + this training data yields the best reconstruction for the function + values at the query points. The training data locations are optimized + iteratively using gradient descent. + """ + tp = _QuadraticPlusSinProblemND() + (query_points, query_values, train_points, + train_values) = tp.get_problem(optimizable=True) + + regularization = 0.001 + for interpolation_order in (1, 2, 3, 4): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order, + regularization) + + loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator)) + + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [train_points]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [train_points])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(100): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0135c66e293693345c3da7fdb21e28ca6d160154 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import sparse_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import image_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + +from tensorflow.python.training import momentum + + +class SparseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def testGetBoundaryLocations(self): + image_height = 11 + image_width = 11 + num_points_per_edge = 4 + locs = sparse_image_warp._get_boundary_locations(image_height, image_width, + num_points_per_edge) + num_points = locs.shape[0] + self.assertEqual(num_points, 4 + 4 * num_points_per_edge) + locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (2, 4, 6, 8): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (0, image_height - 1): + for j in (2, 4, 6, 8): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + def testGetGridLocations(self): + image_height = 5 + image_width = 3 + grid = sparse_image_warp._get_grid_locations(image_height, image_width) + for i in range(image_height): + for j in range(image_width): + self.assertEqual(grid[i, j, 0], i) + self.assertEqual(grid[i, j, 1], j) + + def testZeroShift(self): + """Run assertZeroShift for various hyperparameters.""" + for order in (1, 2): + for regularization in (0, 0.01): + for num_boundary_points in (0, 1): + self.assertZeroShift(order, regularization, num_boundary_points) + + def assertZeroShift(self, order, regularization, num_boundary_points): + """Check that warping with zero displacements doesn't change the image.""" + batch_size = 1 + image_height = 4 + image_width = 4 + channels = 3 + + image = np.random.uniform( + size=[batch_size, image_height, image_width, channels]) + + input_image_op = constant_op.constant(np.float32(image)) + + control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + + control_point_displacements = np.zeros( + control_point_locations.shape.as_list()) + control_point_displacements = constant_op.constant( + np.float32(control_point_displacements)) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + regularization_weight=regularization, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, _ = sess.run( + [warped_image_op, input_image_op, flow_field]) + + self.assertAllClose(warped_image, input_image) + + def testMoveSinglePixel(self): + """Run assertMoveSinglePixel for various hyperparameters and data types.""" + for order in (1, 2): + for num_boundary_points in (1, 2): + for type_to_use in (dtypes.float32, dtypes.float64): + self.assertMoveSinglePixel(order, num_boundary_points, type_to_use) + + def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): + """Move a single block in a small grid using warping.""" + batch_size = 1 + image_height = 7 + image_width = 7 + channels = 3 + + image = np.zeros([batch_size, image_height, image_width, channels]) + image[:, 3, 3, :] = 1.0 + input_image_op = constant_op.constant(image, dtype=type_to_use) + + # Place a control point at the one white pixel. + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0)), + dtype=type_to_use) + # Shift it one pixel to the right. + control_point_displacements = [[0., 1.0]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0)), + dtype=type_to_use) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, flow = sess.run( + [warped_image_op, input_image_op, flow_field]) + # Check that it moved the pixel correctly. + self.assertAllClose( + warped_image[0, 4, 5, :], + input_image[0, 4, 4, :], + atol=1e-5, + rtol=1e-5) + + # Test that there is no flow at the corners. + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertAllClose( + flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) + + def load_image(self, image_file, sess): + image_op = image_ops.decode_png( + io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3] + return sess.run(image_op) + + def testSmileyFace(self): + """Check warping accuracy by comparing to hardcoded warped images.""" + + test_data_dir = test.test_src_dir_path('contrib/image/python/' + 'kernel_tests/test_data/') + input_file = test_data_dir + 'Yellow_Smiley_Face.png' + with self.test_session() as sess: + input_image = self.load_image(input_file, sess) + control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], + [180 - 39, 111], [90, 143], [58, 134], + [180 - 58, 134]]) # pyformat: disable + control_point_displacements = np.asarray( + [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25], + [10, 10.75]]) + control_points_op = constant_op.constant( + np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) + control_point_displacements_op = constant_op.constant( + np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0)) + float_image = np.expand_dims(np.float32(input_image) / 255, 0) + input_image_op = constant_op.constant(float_image) + + for interpolation_order in (1, 2, 3): + for num_boundary_points in (0, 1, 4): + warp_op, _ = sparse_image_warp.sparse_image_warp( + input_image_op, + control_points_op, + control_points_op + control_point_displacements_op, + interpolation_order=interpolation_order, + num_boundary_points=num_boundary_points) + with self.test_session() as sess: + warped_image = sess.run(warp_op) + out_image = np.uint8(warped_image[0, :, :, :] * 255) + target_file = ( + test_data_dir + + 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format( + interpolation_order, num_boundary_points)) + + target_image = self.load_image(target_file, sess) + + # Check that the target_image and out_image difference is no + # bigger than 2 (on a scale of 0-255). Due to differences in + # floating point computation on different devices, the float + # output in warped_image may get rounded to a different int + # than that in the saved png file loaded into target_image. + self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3) + + def testThatBackpropRuns(self): + """Run optimization to ensure that gradients can be computed.""" + + batch_size = 1 + image_height = 9 + image_width = 12 + image = variables.Variable( + np.float32( + np.random.uniform(size=[batch_size, image_height, image_width, 3]))) + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + control_point_displacements = [[0.25, -0.5]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0))) + warped_image, _ = sparse_image_warp.sparse_image_warp( + image, + control_point_locations, + control_point_locations + control_point_displacements, + num_boundary_points=3) + + loss = math_ops.reduce_mean(math_ops.abs(warped_image - image)) + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [image]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [image])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png new file mode 100644 index 0000000000000000000000000000000000000000..7e303881e213a82e412d18de9d9d86f368726f06 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..7fd9e4e6d69f3120428d1d778846d495cea1a989 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..86d225e5d2158804f88dca881f69ed3ab287d866 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..37e8ffae114625d0cc6a07ab2b8dbbb7413a3829 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..e49b5816120d43a669264915f1b6747606e080e0 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..df3cf2004312ed0ed0ebf1f0340cbfec7fd9ac46 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..e1799a87c8542d7e515b6185d7e8f6f75fe73f3e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..2c346e0ce5487e21d41aa4e6306fd83a7b4ffdb4 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..6f8b65451cc08a463e4305ddc4be0dbe2879fae9 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..8e78146d955ae8f02230121e6314f3285e87611e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b219ada492466919c615d8978e462e6c619d33 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using per-pixel flow vectors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _interpolate_bilinear(grid, + query_points, + name='interpolate_bilinear', + indexing='ij'): + """Similar to Matlab's interp2 function. + + Finds values for query points on a grid using bilinear interpolation. + + Args: + grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. + query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. + name: a name for the operation (optional). + indexing: whether the query points are specified as row and column (ij), + or Cartesian coordinates (xy). + + Returns: + values: a 3-D `Tensor` with shape `[batch, N, channels]` + + Raises: + ValueError: if the indexing mode is invalid, or if the shape of the inputs + invalid. + """ + if indexing != 'ij' and indexing != 'xy': + raise ValueError('Indexing mode must be \'ij\' or \'xy\'') + + with ops.name_scope(name): + grid = ops.convert_to_tensor(grid) + query_points = ops.convert_to_tensor(query_points) + shape = grid.get_shape().as_list() + if len(shape) != 4: + msg = 'Grid must be 4 dimensional. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + batch_size, height, width, channels = shape + query_type = query_points.dtype + grid_type = grid.dtype + + if (len(query_points.get_shape()) != 3 or + query_points.get_shape()[2].value != 2): + msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received ' + 'size: ') + raise ValueError(msg + str(query_points.get_shape())) + + _, num_queries, _ = query_points.get_shape().as_list() + + if height < 2 or width < 2: + msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + alphas = [] + floors = [] + ceils = [] + + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) + + for dim in index_order: + with ops.name_scope('dim-' + str(dim)): + queries = unstacked_query_points[dim] + + size_in_indexing_dimension = shape[dim + 1] + + # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 + # is still a valid index into the grid. + max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type) + min_floor = constant_op.constant(0.0, dtype=query_type) + floor = math_ops.minimum( + math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor) + int_floor = math_ops.cast(floor, dtypes.int32) + floors.append(int_floor) + ceil = int_floor + 1 + ceils.append(ceil) + + # alpha has the same type as the grid, as we will directly use alpha + # when taking linear combinations of pixel values from the image. + alpha = math_ops.cast(queries - floor, grid_type) + min_alpha = constant_op.constant(0.0, dtype=grid_type) + max_alpha = constant_op.constant(1.0, dtype=grid_type) + alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha) + + # Expand alpha to [b, n, 1] so we can use broadcasting + # (since the alpha values don't depend on the channel). + alpha = array_ops.expand_dims(alpha, 2) + alphas.append(alpha) + + if batch_size * height * width > np.iinfo(np.int32).max / 8: + error_msg = """The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""" + raise ValueError(error_msg) + + flattened_grid = array_ops.reshape(grid, + [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) + + # This wraps array_ops.gather. We reshape the image data such that the + # batch, y, and x coordinates are pulled into the first dimension. + # Then we gather. Finally, we reshape the output back. It's possible this + # code would be made simpler by using array_ops.gather_nd. + def gather(y_coords, x_coords, name): + with ops.name_scope('gather-' + name): + linear_coordinates = batch_offsets + y_coords * width + x_coords + gathered_values = array_ops.gather(flattened_grid, linear_coordinates) + return array_ops.reshape(gathered_values, + [batch_size, num_queries, channels]) + + # grab the pixel values in the 4 corners around each query point + top_left = gather(floors[0], floors[1], 'top_left') + top_right = gather(floors[0], ceils[1], 'top_right') + bottom_left = gather(ceils[0], floors[1], 'bottom_left') + bottom_right = gather(ceils[0], ceils[1], 'bottom_right') + + # now, do the actual interpolation + with ops.name_scope('interpolate'): + interp_top = alphas[1] * (top_right - top_left) + top_left + interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left + interp = alphas[0] * (interp_bottom - interp_top) + interp_top + + return interp + + +def dense_image_warp(image, flow, name='dense_image_warp'): + """Image warping using per-pixel flow vectors. + + Apply a non-linear warp to the image, where the warp is specified by a dense + flow field of offset vectors that define the correspondences of pixel values + in the output image back to locations in the source image. Specifically, the + pixel value at output[b, j, i, c] is + images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. + + The locations specified by this formula do not necessarily map to an int + index. Therefore, the pixel value is obtained by bilinear + interpolation of the 4 nearest pixels around + (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside + of the image, we use the nearest pixel values at the image boundary. + + + Args: + image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. + flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. + name: A name for the operation (optional). + + Note that image and flow can be of type tf.half, tf.float32, or tf.float64, + and do not necessarily have to be the same type. + + Returns: + A 4-D float `Tensor` with shape`[batch, height, width, channels]` + and same type as input image. + + Raises: + ValueError: if height < 2 or width < 2 or the inputs have the wrong number + of dimensions. + """ + with ops.name_scope(name): + batch_size, height, width, channels = image.get_shape().as_list() + # The flow is defined on the image grid. Turn the flow into a list of query + # points in the grid space. + grid_x, grid_y = array_ops.meshgrid( + math_ops.range(width), math_ops.range(height)) + stacked_grid = math_ops.cast( + array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) + batched_grid = array_ops.expand_dims(stacked_grid, axis=0) + query_points_on_grid = batched_grid - flow + query_points_flattened = array_ops.reshape(query_points_on_grid, + [batch_size, height * width, 2]) + # Compute values at the query points, then reshape the result back to the + # image grid. + interpolated = _interpolate_bilinear(image, query_points_flattened) + interpolated = array_ops.reshape(interpolated, + [batch_size, height, width, channels]) + return interpolated diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py new file mode 100644 index 0000000000000000000000000000000000000000..daf8c56456327f102f1409296a91f9f7b68ec799 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py @@ -0,0 +1,291 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2) + y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1) + + x_y_transpose = math_ops.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile + + return squared_dists + + +def _pairwise_squared_distance_matrix(x): + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x) + + Args: + x: `[batch_size, n, d]` float `Tensor` + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 + """ + + x_x_transpose = math_ops.matmul(x, x, adjoint_b=True) + x_norm_squared = array_ops.matrix_diag_part(x_x_transpose) + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose( + x_norm_squared_tile, [0, 2, 1]) + + return squared_dists + + +def _solve_interpolation(train_points, train_values, order, + regularization_weight): + """Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the 'training' + data defined by (train_points, train_values) using the kernel phi. + + Args: + train_points: `[b, n, d]` interpolation centers + train_values: `[b, n, k]` function values + order: order of the interpolation + regularization_weight: weight to place on smoothness regularization term + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + """ + + b, n, d = train_points.get_shape().as_list() + _, _, k = train_values.get_shape().as_list() + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with ops.name_scope('construct_linear_system'): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = np.expand_dims(np.eye(n), 0) + batch_identity_matrix = constant_op.constant( + batch_identity_matrix, dtype=train_points.dtype) + + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term in the linear model. + ones = array_ops.ones([b, n, 1], train_points.dtype) + matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = array_ops.concat( + [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = array_ops.concat([matrix_b, lhs_zeros], + 1) # [b, n + d + 1, d + 1] + lhs = array_ops.concat([left_block, right_block], + 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype) + rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with ops.name_scope('solve_linear_system'): + w_v = linalg_ops.matrix_solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + + batch_size = train_points.get_shape()[0].value + num_query_points = query_points.get_shape()[1].value + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = math_ops.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = array_ops.concat([ + query_points, + array_ops.ones([batch_size, num_query_points, 1], train_points.dtype) + ], 2) + linear_term = math_ops.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op + order: interpolation order + + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + with ops.name_scope('phi'): + if order == 1: + r = math_ops.maximum(r, EPSILON) + r = math_ops.sqrt(r) + return r + elif order == 2: + return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * math_ops.square(r) * math_ops.log( + math_ops.maximum(r, EPSILON)) + elif order % 2 == 0: + r = math_ops.maximum(r, EPSILON) + return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r) + else: + r = math_ops.maximum(r, EPSILON) + return math_ops.pow(r, 0.5 * order) + + +def interpolate_spline(train_points, + train_values, + query_points, + order, + regularization_weight=0.0, + name='interpolate_spline'): + r"""Interpolate signal using polyharmonic interpolation. + + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function (RBF) + terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias. + The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function at + the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the + vector w sums to 0. With these constraints, the coefficients can be obtained + by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the training + data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may be + less vulnerable to overfitting, particularly for high-order interpolation. + + Note the interpolation procedure is differentiable with respect to all inputs + besides the order parameter. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional values + evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline), + or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with ops.name_scope(name): + train_points = ops.convert_to_tensor(train_points) + train_values = ops.convert_to_tensor(train_values) + query_points = ops.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with ops.name_scope('solve'): + w, v = _solve_interpolation(train_points, train_values, order, + regularization_weight) + + # Then, evaluate the spline at the query locations. + with ops.name_scope('predict'): + query_values = _apply_interpolation(query_points, train_points, w, v, + order) + + return query_values diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..54a215d6db6ded56a1a4a018a7e176f35fe6397e --- /dev/null +++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') + return np.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(np_array, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + tiles = [batch_size] + [1] * np_array.ndim + return np.tile(np.expand_dims(np_array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) + x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) + ys, xs = np.meshgrid(y_range, x_range, indexing='ij') + is_boundary = np.logical_or( + np.logical_or(xs == 0, xs == image_width - 1), + np.logical_or(ys == 0, ys == image_height - 1)) + return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = control_point_locations.get_shape()[0].value + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + + boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) + + type_to_use = control_point_locations.dtype + boundary_point_locations = constant_op.constant( + _expand_to_minibatch(boundary_point_locations, batch_size), + dtype=type_to_use) + + boundary_point_flows = constant_op.constant( + _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) + + merged_control_point_locations = array_ops.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = array_ops.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (@{tf.contrib.image.dense_image_warp}). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See @{tf.contrib.image.interpolate_spline} for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. + """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + + batch_size, image_height, image_width, _ = image.get_shape().as_list() + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = np.reshape(grid_locations, + [image_height * image_width, 2]) + + flattened_grid_locations = constant_op.constant( + _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 9d6b4d5d87e24d72b29ab33ee805fe0d068cc30a..0e34315db45d61282af1882631dc769a72965c3e 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -114,14 +114,3 @@ tf_cc_tests( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/input_pipeline/kernels/BUILD b/tensorflow/contrib/input_pipeline/kernels/BUILD index f20a6e38d4e80f869e9274d6fc49338a95fc6788..797605b8fe66e8375edcc70668a07a8d2a6d73f3 100644 --- a/tensorflow/contrib/input_pipeline/kernels/BUILD +++ b/tensorflow/contrib/input_pipeline/kernels/BUILD @@ -17,14 +17,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD index 66948c1ea1f3f239d3f43a57626f8c229fe24ad9..0b7d64f4edd7587000ca5b9ecae257fe8fedd4a1 100644 --- a/tensorflow/contrib/integrate/BUILD +++ b/tensorflow/contrib/integrate/BUILD @@ -42,14 +42,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD index 1c3974871c62911c0cb47677eb92d28286837142..3913c9dc7abfba2829bde5e86fe2927e8fc29a9d 100644 --- a/tensorflow/contrib/kafka/BUILD +++ b/tensorflow/contrib/kafka/BUILD @@ -119,17 +119,3 @@ tf_py_test( "notap", ], ) - -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 7e0019ce4ad6c96e09ac9e222e2f4e2840273983..7a4cab20d1a3471af2a2a402a6d1443a90fa7f9b 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -52,15 +52,3 @@ py_library( "//tensorflow/python/keras", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index eff7dfeb4c1117e40f4faf43c5e92a52cffd6528..87c2dcd89b63fa9f92d93c87abce91fd3460d44e 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -90,15 +90,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index f182fef067b7f523bc5ca63227265be40528b171..4ef0a66a52429233c6e6f70667a451466493629c 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -43,10 +43,10 @@ def sparse_multiclass_hinge_loss( This is a generalization of standard (binary) hinge loss. For a given instance with correct label c*, the loss is given by: - loss = max_{c != c*} logits_c - logits_{c*} + 1. + $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$ or equivalently - loss = max_c { logits_c - logits_{c*} + I_{c != c*} } - where I_{c != c*} = 1 if c != c* and 0 otherwise. + $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$ + where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise. Args: labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py index 9dc01124ab195ae17b8795a11e4ebefe3f2c746b..091f0a109801065f06110e2a313c24486d38109f 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py @@ -35,23 +35,23 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper): The RFFM mapping is used to approximate the Gaussian (RBF) kernel: ``` - exp(-||x-y||_2^2 / (2 * sigma^2)) + $$(exp(-||x-y||_2^2 / (2 * \sigma^2))$$ ``` The implementation of RFFM is based on the following paper: "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht. (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) - The mapping uses a matrix `Omega \in R^{d x D}` and a bias vector `b \in R^D` - where `d` is the input dimension (number of dense input features) and `D` is - the output dimension (i.e., dimension of the feature space the input is mapped - to). Each entry of `Omega` is sampled i.i.d. from a (scaled) Gaussian - distribution and each entry of `b` is sampled independently and uniformly from - [0, 2 * pi]. + The mapping uses a matrix `\\(Omega \in R^{d x D}\\)` and a bias vector + `\\(b \in R^D\\)` where `d` is the input dimension (number of dense input + features) and `D` is the output dimension (i.e., dimension of the feature + space the input is mapped to). Each entry of `Omega` is sampled i.i.d. from a + (scaled) Gaussian distribution and each entry of `b` is sampled independently + and uniformly from [0, \\(2 * pi\\)]. For a single input feature vector x in R^d, its RFFM is defined as: ``` - sqrt(2/D) * cos(x * Omega + b) + $$sqrt(2/D) * cos(x * Omega + b)$$ ``` where `cos` is the element-wise cosine function and `x, b` are represented as row vectors. The aforementioned paper shows that the linear kernel of diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 6f4a264485993ab737723171409042b4a9673669..91929184a2e6f3cccae92cb819501a7c6ef81673 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -34,7 +34,7 @@ def _inner_product(x, y): """Inner product between tensors x and y. The input tensors are assumed to be in ROW representation, that is, the method - returns x * y^T. + returns \\(x * y^T\\). Args: x: input tensor in row format diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD index 9a5759bf14f753bbc50d3ef8f54ceab7daf745ab..b719046b37ac761d56e8d5aa34772103be691cd6 100644 --- a/tensorflow/contrib/kfac/BUILD +++ b/tensorflow/contrib/kfac/BUILD @@ -24,15 +24,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD index 89965eda374b2b403f680fc77eb923d0e660d1e2..8186fa1c62cb952f86614a96c3965bcddae1686e 100644 --- a/tensorflow/contrib/kfac/examples/BUILD +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -28,8 +28,28 @@ py_library( ) py_binary( - name = "convnet_mnist_main", - srcs = ["convnet_mnist_main.py"], + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], srcs_version = "PY2AND3", deps = [ ":convnet", @@ -58,15 +78,3 @@ py_library( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 39d80addaac1fe855a37255b32bf4412b99df46a..e8e3353091df25e135b1247bf976bb9ce177d1a7 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -37,6 +37,8 @@ import tensorflow as tf from tensorflow.contrib.kfac.examples import mlp from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + lc = tf.contrib.kfac.layer_collection oq = tf.contrib.kfac.op_queue @@ -48,12 +50,18 @@ __all__ = [ "linear_layer", "build_model", "minimize_loss_single_machine", - "minimize_loss_distributed", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", "train_mnist_single_machine", - "train_mnist_distributed", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" ] +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + def conv_layer(layer_id, inputs, kernel_size, out_channels): """Builds a convolutional layer with ReLU non-linearity. @@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection): accuracy = tf.reduce_mean( tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) # Register parameters. K-FAC needs to know about the inputs, outputs, and # parameters of each conv/fully connected layer and the logits powering the @@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection): def minimize_loss_single_machine(loss, accuracy, layer_collection, + device="/gpu:0", session_config=None): """Minimize loss with K-FAC on a single machine. - A single Session is responsible for running all of K-FAC's ops. + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() + g_step = tf.train.get_or_create_global_step() optimizer = opt.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], momentum=0.9) - train_op = optimizer.minimize(loss, global_step=global_step) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) - if global_step_ % 100 == 0: + if (global_step_ + 1) % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks): return int(np.ceil(0.6 * num_tasks)) -def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection): - """Minimize loss with an synchronous implementation of K-FAC. +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. - Different tasks are responsible for different parts of K-FAC's Ops. The first - 60% of tasks update weights; the next 20% accumulate covariance statistics; - the last 20% invert the matrices used to precondition gradients. + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. + + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution untill the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. @@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. @@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, Raises: ValueError: if task_id >= num_worker_tasks. """ - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) - train_op = sync_optimizer.minimize(loss, global_step=global_step) + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + update_op = tf.cond( + tf.equal(tf.mod(global_step + 1, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + else: + update_op = train_op + + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, update_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) tf.logging.info("Starting training.") is_chief = (task_id == 0) @@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = optimizer.cov_update_op + learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): # TODO(duckworthd): Running this op before cov_update_op has been run a # few times can result in "InvalidArgumentError: Cholesky decomposition @@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, return accuracy_ -def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): """Train a ConvNet on MNIST. Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. @@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): examples, labels, num_labels=10, layer_collection=layer_collection) # Fit model. - return minimize_loss_single_machine(loss, accuracy, layer_collection) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True): + use_fake_data=True, devices=None): """Train a ConvNet on MNIST. + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 @@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, layer_collection = lc.LayerCollection() tower_results = [] for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): + with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) @@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, accuracy = tf.reduce_mean(accuracies) # Fit model. + session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize_loss_single_machine( - loss, accuracy, layer_collection, session_config=session_config) + allow_soft_placement=False, + device_count=device_count, + ) + + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = optimizer.minimize(loss, global_step=g_step) -def train_mnist_distributed(task_id, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - use_fake_data=False): - """Train a ConvNet on MNIST. + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) + + if (global_step_ + 1) % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covaraiance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") @@ -448,9 +649,17 @@ def train_mnist_distributed(task_id, # Fit model. checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, - master, checkpoint_dir, loss, accuracy, - layer_collection) + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) if __name__ == "__main__": diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c2d4a9e9bfcc4bfb55a25d2f23e66afe5b1375 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py similarity index 57% rename from tensorflow/contrib/kfac/examples/convnet_mnist_main.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py index b0c6fbde198850c76af0bc1600dc23e926227229..4249bf8a8d9d3a5beb87d4140a55b0ee6eadbc64 100644 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -14,44 +14,35 @@ # ============================================================================== r"""Train a ConvNet on MNIST using K-FAC. -See convnet.py for details. +Multi tower training mode. See `convnet.train_mnist_multitower` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import flags import tensorflow as tf from tensorflow.contrib.kfac.examples import convnet -FLAGS = None +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") -def main(argv): - _ = argv - - if FLAGS.num_towers > 1: - convnet.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + tf.app.run(main=main) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py similarity index 63% rename from tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py index e7fcbc65ef379e84a140a06e020549f74f905a99..3aa52aff196fd2699559f80b0c226f470c94b2a3 100644 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -12,23 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step.""" +r"""Train a ConvNet on MNIST using K-FAC. + +Train on single machine. See `convnet.train_mnist_single_machine` for details. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - 'kernel', - 'evolve', - 'proposal_uniform', - 'proposal_normal', -] +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") + + +def main(unused_argv): + convnet.train_mnist_single_gpu(FLAGS.data_dir, num_epochs=200) + -remove_undocumented(__name__, _allowed_symbols) +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD index ce7da95c124beaed4773d68ce0d0c41f187f7c9d..ede7f183fe24f26bd86e232e831dea5f8ea1fdc4 100644 --- a/tensorflow/contrib/kfac/examples/tests/BUILD +++ b/tensorflow/contrib/kfac/examples/tests/BUILD @@ -50,15 +50,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 8d86c2bb5150cd4bc8a2b21ba050e904929e0fe9..6de775cc79953ba548c766e861d6d88e0455a508 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase): def testMinimizeLossSingleMachine(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, - layer_collection) - self.assertLess(accuracy_, 1.0) + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) def testMinimizeLossDistributed(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_distributed( + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", @@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase): loss=loss, accuracy=accuracy, layer_collection=layer_collection) - self.assertLess(accuracy_, 1.0) + self.assertLess(accuracy_, 2.0) def testTrainMnistSingleMachine(self): with tf.Graph().as_default(): @@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase): # but there are too few parameters for the model to effectively memorize # the training set the way an MLP can. convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True) + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") def testTrainMnistMultitower(self): with tf.Graph().as_default(): @@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase): def testTrainMnistDistributed(self): with tf.Graph().as_default(): # Ensure model training doesn't crash. - convnet.train_mnist_distributed( + convnet.train_mnist_distributed_sync_replicas( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", data_dir=None, num_epochs=1, + op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index d1c449402a697dd5f8876c82a6682dde2d18b4df..2477d2bfc12c2df64a672fd457e9634009ccd129 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -156,15 +156,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 30c5404e03910eedb48132b0d69b2eabb89a9149..f22dbcf21566297340f3b4158a810f6d03af12f5 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,30 +39,6 @@ from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] -class DeviceContextGeneratorTest(test.TestCase): - - def testNoDevice(self): - device_context_generator = estimator._DeviceContextGenerator(None) - with ops.device("/device:CPU:0"): # This is what will be used - with device_context_generator(): # Does nothing - a = constant_op.constant([2.0], name="a") - self.assertEqual("/device:CPU:0", a.op.device) - - def testTwoDevices(self): - device_context_generator = estimator._DeviceContextGenerator( - ["/device:GPU:0", "/device:GPU:1"]) - with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes - with device_context_generator(): - a = constant_op.constant([2.0], name="a") - with device_context_generator(): - b = constant_op.constant([2.0], name="b") - with device_context_generator(): - c = constant_op.constant([2.0], name="c") - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual("/device:GPU:1", b.op.device) - self.assertEqual("/device:GPU:0", c.op.device) - - class EstimatorTest(test.TestCase): def setUp(self): @@ -90,68 +65,98 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) est.make_ops_and_vars() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) est.make_ops_and_vars() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) est.make_ops_and_vars() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="not_a_real_mode") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") est.make_ops_and_vars() def testGradientsModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="gradients") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") est.make_ops_and_vars() def testEmpiricalModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="empirical") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") est.make_ops_and_vars() def testCurvaturePropModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="curvature_prop") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") est.make_ops_and_vars() def testExactModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="exact") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") est.make_ops_and_vars() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, @@ -159,8 +164,8 @@ class EstimatorTest(test.TestCase): # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() - (cov_variable_thunks, cov_update_op_thunks, - _, _) = fisher_estimator.create_ops_and_vars_thunks() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() for thunk in cov_variable_thunks: thunk() cov_matrices = [ @@ -198,10 +203,43 @@ class EstimatorTest(test.TestCase): sess.run(cov_update_op) sess.run(increment_global_step) + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_ops, _, inv_update_ops, _, _, + _) = fisher_estimator.make_ops_and_vars(scope="test") + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. + self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index b70c700f0936c2d8a2eca6e0836a3ee4ffe4e46d..6eda6c31e34370fd2bea1192ebf777924824c8e3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -63,7 +63,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -72,7 +72,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -81,7 +81,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -91,7 +91,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) block._factor.instantiate_cov_variables() @@ -112,7 +112,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) block._factor.instantiate_cov_variables() @@ -133,7 +133,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) damping = 0.5 block.instantiate_factors((grads,), damping) @@ -163,7 +163,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -172,7 +172,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -181,7 +181,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -191,7 +191,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) block._factor.instantiate_cov_variables() @@ -210,7 +210,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) block._factor.instantiate_cov_variables() @@ -228,7 +228,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) damping = 0.5 block.instantiate_factors((grads,), damping) @@ -324,8 +324,8 @@ class FullyConnectedDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -376,7 +376,7 @@ class FullyConnectedDiagonalFBTest(test.TestCase): block = fb.FullyConnectedDiagonalFB( lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) block._factor.instantiate_cov_variables() @@ -402,7 +402,7 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. @@ -420,7 +420,7 @@ class EmbeddingKFACFBTest(test.TestCase): # Add some examples. inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) # Instantiate factor's variables. Ensure it doesn't fail. grads = outputs**2. @@ -461,7 +461,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) @@ -471,7 +471,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -482,7 +482,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -493,7 +493,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -525,7 +525,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) block._input_factor.instantiate_cov_variables() @@ -553,7 +553,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): outputs = array_ops.zeros([32, output_dim]) params = array_ops.zeros([input_dim, output_dim]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. block.instantiate_factors(((grads,),), damping) @@ -689,8 +689,8 @@ class ConvDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result, atol=1e-3) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -751,7 +751,7 @@ class ConvDiagonalFBTest(test.TestCase): block = fb.ConvDiagonalFB( lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) block._factor.instantiate_cov_variables() @@ -775,7 +775,7 @@ class DepthwiseConvKFCBasicFBTest(test.TestCase): layer_collection = lc.LayerCollection() block = fb.DepthwiseConvKFCBasicFB( layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(([grads],), 0.5) @@ -788,7 +788,7 @@ class DepthwiseConvKFCBasicFBTest(test.TestCase): layer_collection = lc.LayerCollection() block = fb.DepthwiseConvKFCBasicFB( layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(([grads],), 0.5) block._input_factor.instantiate_cov_variables() @@ -825,7 +825,7 @@ class ConvKFCBasicFBTest(test.TestCase): outputs = random_ops.random_normal((2, 2, 2)) block = fb.ConvKFCBasicFB( lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) @@ -843,7 +843,7 @@ class ConvKFCBasicFBTest(test.TestCase): outputs = random_ops.random_normal((2, 2, 2, 2)) block = fb.ConvKFCBasicFB( lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) block._input_factor.instantiate_cov_variables() @@ -874,7 +874,7 @@ class ConvKFCBasicFBTest(test.TestCase): outputs = random_ops.random_normal((2, 2, 2, 2)) block = fb.ConvKFCBasicFB( lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertFalse(block._has_bias) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -902,7 +902,7 @@ class ConvKFCBasicFBTest(test.TestCase): outputs = random_ops.random_normal((2, 2, 2, 2)) block = fb.ConvKFCBasicFB( lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertTrue(block._has_bias) grads = outputs**2 block.instantiate_factors(((grads,),), 0.5) @@ -930,7 +930,7 @@ class ConvKFCBasicFBTest(test.TestCase): outputs = array_ops.zeros((2, 2, 2, 2)) block = fb.ConvKFCBasicFB( lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. block.instantiate_factors(((grads,),), damping) @@ -964,7 +964,7 @@ class FullyConnectedSeriesFBTest(test.TestCase): inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) - block.register_additional_minibatch([inputs], [outputs]) + block.register_additional_tower([inputs], [outputs]) self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) def testInstantiateFactorsHasBias(self): @@ -975,7 +975,7 @@ class FullyConnectedSeriesFBTest(test.TestCase): block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), has_bias=True) - block.register_additional_minibatch([inputs], [outputs]) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 block.instantiate_factors((((grads,),),), 0.5) @@ -987,7 +987,7 @@ class FullyConnectedSeriesFBTest(test.TestCase): block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch([inputs], [outputs]) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 block.instantiate_factors((((grads,),),), 0.5) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 16f02f1199ad8a404b0e6944fc89df32ce08609c..2a3592c53fdda488561e504ba2712aadc3214cc4 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -85,6 +85,12 @@ class FisherFactorTestingDummy(ff.FisherFactor): def instantiate_inv_variables(self): return NotImplementedError + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -116,6 +122,12 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def instantiate_covariance(self): pass + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class NumericalUtilsTest(test.TestCase): @@ -430,7 +442,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size) + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) factor.instantiate_cov_variables() cov = factor.get_cov_var() self.assertEqual(cov.shape.as_list(), [vocab_size]) @@ -439,7 +451,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): input_ids = array_ops.constant([[0], [1], [4]]) vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor(input_ids, vocab_size) + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) factor.instantiate_cov_variables() cov_update_op = factor.make_covariance_update_op(0.0) @@ -477,8 +489,8 @@ class ConvDiagonalFactorTest(test.TestCase): ] factor = ff.ConvDiagonalFactor( - inputs, - outputs_grads, + (inputs,), + (outputs_grads,), self.kernel_shape, self.strides, self.padding, @@ -508,7 +520,8 @@ class ConvDiagonalFactorTest(test.TestCase): self.out_channels) factor = ff.ConvDiagonalFactor( - constant_op.constant(inputs), [constant_op.constant(outputs_grad)], + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), self.kernel_shape, strides=[1, 1, 1, 1], padding='VALID') @@ -537,8 +550,8 @@ class ConvDiagonalFactorTest(test.TestCase): ] factor = ff.ConvDiagonalFactor( - inputs, - outputs_grads, + (inputs,), + (outputs_grads,), self.kernel_shape, self.strides, self.padding, @@ -569,7 +582,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) @@ -587,7 +600,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -598,7 +611,7 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,)) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -629,8 +642,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): out_channels = 4 factor = ff.ConvInputKroneckerFactor( - inputs=random_ops.random_uniform( - (batch_size, width, width, width, in_channels), seed=0), + inputs=(random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0),), filter_shape=(width, width, width, in_channels, out_channels), padding='SAME', strides=(2, 2, 2), @@ -661,8 +674,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): out_channels = 4 factor = ff.ConvInputKroneckerFactor( - inputs=random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0), + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), filter_shape=(1, 1, in_channels, out_channels), padding='SAME', strides=(1, 1, 1, 1), @@ -691,8 +704,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): out_channels = 4 factor = ff.ConvInputKroneckerFactor( - inputs=random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0), + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), filter_shape=(1, 1, in_channels, out_channels), padding='SAME', strides=(1, 2, 1, 1), @@ -716,8 +729,8 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): out_channels = 4 factor = ff.ConvInputKroneckerFactor( - inputs=random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0), + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), filter_shape=(3, 3, in_channels, out_channels), padding='SAME', extract_patches_fn='extract_image_patches', @@ -739,7 +752,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with tf_ops.Graph().as_default(): tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - inputs=tensor, + inputs=(tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=False) @@ -751,7 +764,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with tf_ops.Graph().as_default(): tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) @@ -761,7 +774,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): dtype = dtypes.float64_ref tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) factor = ff.ConvInputKroneckerFactor( - tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) @@ -775,7 +788,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -794,7 +807,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, filter_shape=(1, 1, 1, 1), padding='SAME') + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -810,10 +823,10 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase): width = 3 out_channels = width**3 - factor = ff.ConvOutputKroneckerFactor(outputs_grads=[ + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ random_ops.random_uniform( (batch_size, width, width, width, out_channels), seed=0) - ]) + ],)) factor.instantiate_cov_variables() with self.test_session() as sess: @@ -829,7 +842,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) factor.instantiate_cov_variables() self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) @@ -838,7 +851,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) @@ -848,7 +861,7 @@ class ConvOutputKroneckerFactorTest(ConvFactorTestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),)) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -862,8 +875,7 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) factor.instantiate_cov_variables() self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) @@ -872,8 +884,7 @@ class FullyConnectedMultiKFTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) @@ -883,8 +894,7 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) @@ -895,8 +905,7 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,)) + factor = ff.FullyConnectedMultiKF(((tensor,),)) factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index bae6bd7a3bd47bc50378afe95d26d57535377f6f..cb80fca3705308f92e308e2a840336fb72d0fa62 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class MockFisherBlock(object): """A fake FisherBlock.""" - num_registered_minibatches = 2 + num_registered_towers = 2 def __init__(self, name='MockFisherBlock'): self.name = name @@ -135,8 +135,22 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) - self.assertEqual(9, len(lc.get_blocks())) + self.assertEqual(12, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -454,13 +468,13 @@ class LayerCollectionTest(test.TestCase): b = variable_scope.get_variable('b', [3]) lc = layer_collection.LayerCollection() lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) with self.assertRaises(KeyError): lc.register_fully_connected((w, b), inputs, outputs, reuse=True) self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) def testMakeOrGetFactor(self): with ops.Graph().as_default(): diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index c26230c2a82ae9529ab13b523b9ec287d17debaf..b897fd68a080e819042cd36f2a1acfcf175e656b 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -171,6 +171,7 @@ py_library( name = "fisher_estimator", srcs = [ "estimator.py", + "placement.py", ], srcs_version = "PY2AND3", deps = [ @@ -180,6 +181,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -242,15 +244,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 64755be65c4b5686397dbfd798fec1ed70ae61dc..ced1110676754b6c8bba813ace743b3f3daddb26 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import itertools - +import abc import numpy as np +import six +from tensorflow.contrib.kfac.python.ops import placement from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops @@ -31,63 +31,46 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -class _DeviceContextGenerator(object): - """Class for generating device contexts in a round-robin fashion.""" - - def __init__(self, devices): - """Creates a _DeviceContextGenerator object. - - Example usage: +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. - ```python - dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) - with dcg(): - # All operations in this context will be placed on GPU 0 - ... - with dcg(): - # All operations in this context will be placed on GPU 1 - ... - ``` - - Args: - devices: An iterable of device strings (or None). Successive calls to - __call__ will give contexts which place devices on these devices in - a round-robin fashion. - """ - self._cycle = None if devices is None else itertools.cycle(devices) + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. - @contextlib.contextmanager - def __call__(self): - """Returns a context manager specifying the default device.""" - if self._cycle is None: - yield - else: - with tf_ops.device(next(self._cycle)): - yield + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. -def _make_thunk_on_device(func, device): - def thunk(): - with tf_ops.device(device): - return func() - return thunk + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops placement strategy : %s", + placement_strategy) +# pylint: enable=abstract-class-instantiated +@six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. - Attributes: - cov_update_thunks: list of no-arg functions. Executing a function adds - covariance update ops for a single FisherFactor to the graph. - cov_update_ops: List of Ops. Running an op updates covariance matrices for a - single FisherFactor. - cov_update_op: Op. Running updates covariance matrices for all - FisherFactors. - inv_update_thunks: list of no-arg functions. Executing a function adds - inverse update ops for a single FisherFactor to the graph. - inv_update_ops: List of Ops. Running an op updates inverse matrices for a - single FisherFactor. - inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. """ def __init__(self, @@ -184,6 +167,77 @@ class FisherEstimator(object): def name(self): return self._name + @abc.abstractmethod + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. For example in case of + round robin placement a new device is chosen for each factor by cycling + through list of devices in the cov_devices argument. If cov_devices is None + then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + pass + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. @@ -286,158 +340,6 @@ class FisherEstimator(object): self._instantiate_factors() self._register_matrix_functions() - def make_ops_and_vars(self, scope=None): - """Make ops and vars with no specific device placement. - - See make_ops_and_vars_round_robin for further details. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - return self.make_ops_and_vars_round_robin(scope=scope) - - # TODO(b/70674513): Factor device placement outside of this class. - def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None, - inv_devices=None): - """Make ops and vars with a round-robin device placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - (cov_update_thunks, - inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin( - scope=scope, - cov_devices=cov_devices, - inv_devices=inv_devices) - cov_update_ops = [thunk() for thunk in cov_update_thunks] - inv_update_ops = [thunk() for thunk in inv_update_thunks] - - scope = self.name if scope is None else scope - with variable_scope.variable_scope(scope): - cov_update_op = control_flow_ops.group(cov_update_ops, - name="cov_update_op") - inv_update_op = control_flow_ops.group(inv_update_ops, - name="inv_update_op") - - return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, - cov_update_thunks, inv_update_thunks) - - def make_vars_and_create_op_thunks_round_robin(self, - scope=None, - cov_devices=None, - inv_devices=None): - """Make vars and create op thunks w/ a round-robin device placement strat. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all thunks will execute, inside of a variable scope of the given - name. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - Returns: - cov_update_thunks: List of cov update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - inv_update_thunks: List of inv update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - """ - - (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, - inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) - - if cov_devices: - cov_update_thunks = [] - for cov_variable_thunk, cov_update_thunk, device in zip( - cov_variable_thunks_raw, cov_update_thunks_raw, - itertools.cycle(cov_devices)): - with tf_ops.device(device): - cov_variable_thunk() - cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, - device)) - else: - for cov_variable_thunk in cov_variable_thunks_raw: - cov_variable_thunk() - cov_update_thunks = cov_update_thunks_raw - - for inv_variable_thunk in inv_variable_thunks_raw: - inv_variable_thunk() - - if inv_devices: - inv_update_thunks = [] - for inv_update_thunk, device in zip(inv_update_thunks_raw, - itertools.cycle(inv_devices)): - inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, - device)) - else: - inv_update_thunks = inv_update_thunks_raw - - return cov_update_thunks, inv_update_thunks - def create_ops_and_vars_thunks(self, scope=None): """Create thunks that make the ops and vars on demand. @@ -582,3 +484,9 @@ class FisherEstimator(object): colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 31f4689fbfbbf13872c237913a37478f3c2debe0..00b3673a742e92057b0a1673d3f42a19379111fe 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -19,11 +19,11 @@ Information matrix. Suppose one has a model that parameterizes a posterior distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its Fisher Information matrix is given by, - F(params) = E[ v(x, y, params) v(x, y, params)^T ] + $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ where, - v(x, y, params) = (d / d params) log p(y | x, params) + $$v(x, y, params) = (d / d params) log p(y | x, params)$$ and the expectation is taken with respect to the data's distribution for 'x' and the model's posterior distribution for 'y', @@ -48,6 +48,7 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest # For blocks corresponding to convolutional layers, or any type of block where # the parameters can be thought of as being replicated in time or space, @@ -84,7 +85,7 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): """Computes the scalar constant pi for Tikhonov regularization/damping. - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: @@ -159,7 +160,7 @@ class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. Subclasses must implement register_matpower, multiply_matpower, - instantiate_factors, tensors_to_compute_grads, and num_registered_minibatches + instantiate_factors, tensors_to_compute_grads, and num_registered_towers methods. """ @@ -234,8 +235,8 @@ class FisherBlock(object): pass @abc.abstractproperty - def num_registered_minibatches(self): - """Number of minibatches registered for this FisherBlock. + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. Typically equal to the number of towers in a multi-tower setup. """ @@ -287,8 +288,8 @@ class FullFB(FisherBlock): def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -296,7 +297,7 @@ class FullFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -349,8 +350,8 @@ class NaiveDiagonalFB(FisherBlock): def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -358,7 +359,7 @@ class NaiveDiagonalFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -366,24 +367,78 @@ class NaiveDiagonalFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class InputOutputMultiMinibatch(object): +class InputOutputMultiTower(object): """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" def __init__(self, *args, **kwargs): self.__inputs = [] self.__outputs = [] - super(InputOutputMultiMinibatch, self).__init__(*args, **kwargs) + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). + + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list def tensors_to_compute_grads(self): """Tensors to compute derivative of loss with respect to.""" - return self._outputs + return tuple(self._outputs) - def register_additional_minibatch(self, inputs, outputs): + def register_additional_tower(self, inputs, outputs): self._inputs.append(inputs) self._outputs.append(outputs) @property - def num_registered_minibatches(self): + def num_registered_towers(self): result = len(self._inputs) assert result == len(self._outputs) return result @@ -396,59 +451,8 @@ class InputOutputMultiMinibatch(object): def _outputs(self): return self.__outputs - def _package_minibatches(self, grads_list): - """Constructs PartitionedTensor for inputs, grads_list. - - The purpose of this method is to package up the towers/minibatch dimension - of these arrays into PartitionedTensor objects. - - Args: - grads_list: 2-D list of Tensors. First index is for source, second - index for tower. - - Returns: - inputs: PartitionedTensor. - grads_list: Tuple of PartitionedTensors, one per source. - """ - inputs = utils.PartitionedTensor(self._inputs) - grads_list = tuple(utils.PartitionedTensor(grads) for grads in grads_list) - - return inputs, grads_list - def _package_minibatches_multi(self, grads_list): - """Constructs PartitionedTensors for inputs, grads_list. - - The purpose of this method is to package up the towers/minibatch dimension - of these arrays into PartitionedTensor objects. - - This version of this function is for use with FisherBlocks that deal with - multiple uses or time-steps. One PartitionedTensor is created for each - use/time-step. - - Args: - grads_list: 3-D tuple of Tensors. First index is for source, second - index is for tower, third is for use/time-step. - - Returns: - inputs: A tuple of PartitionedTensor's, one per use/time-step. - grads_list: 2-D tuple of PartitionedTensors. First index is for source, - second is for use/time-step. - """ - # self._inputs is a 2-D tuple. First index is tower/mini-batch, second is - # use/time-step. - inputs = self._inputs - num_uses = len(inputs[0]) - assert all(len(input_) == num_uses for input_ in inputs) - assert all(len(grad) == num_uses for grads in grads_list for grad in grads) - - inputs = tuple(utils.PartitionedTensor(input_) for input_ in zip(*inputs)) - grads_list = tuple(tuple(utils.PartitionedTensor(grad) - for grad in zip(*grads)) for grads in grads_list) - - return inputs, grads_list - - -class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): +class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -458,14 +462,14 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider fully connected layer in this model with (unshared) weight matrix 'w'. For an example 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( a (d loss / d s)^T ) + $$v(x, y, w) = vec( a (d loss / d s)^T )$$ This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'. @@ -485,7 +489,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._package_minibatches(grads_list) + inputs, grads_list = self._process_data(grads_list) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedDiagonalFactor, @@ -518,7 +522,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock): return utils.mat2d_to_layer_params(vector, reshaped_out) -class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): +class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional @@ -528,14 +532,14 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider a convoluational layer in this model with (unshared) filter matrix 'w'. For an example image 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) + $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ where 'loc' is a single (x, y) location in an image. @@ -598,10 +602,10 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._package_minibatches(grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._factor = self._layer_collection.make_or_get_factor( @@ -630,7 +634,7 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock): class KroneckerProductFB(FisherBlock): - """A base class for FisherBlocks with separate input and output factors. + """A base class for blocks with separate input and output Kronecker factors. The Fisher block is approximated as a Kronecker product of the input and output factors. @@ -708,10 +712,10 @@ class KroneckerProductFB(FisherBlock): right_factor) -class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB): +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for embedding layers. - This FisherBlock is similar to EmbeddingKFACFB, except that its + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its input factor is approximated by a diagonal matrix. In the case that each example references exactly one embedding, this approximation is exact. @@ -740,18 +744,17 @@ class EmbeddingKFACFB(InputOutputMultiMinibatch, KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - inputs, grads_list = self._package_minibatches(grads_list) + inputs, grads_list = self._process_data(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.EmbeddingInputKroneckerFactor, # + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # - (grads_list,)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) self._setup_damping(damping) -class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): +class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. This uses the Kronecker-factorized approximation from the original @@ -781,18 +784,18 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - inputs, grads_list = self._package_minibatches(grads_list) + inputs, grads_list = self._process_data(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) self._setup_damping(damping) -class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): +class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): """FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional @@ -802,12 +805,12 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates, - F(w) = #locations * kronecker(E[flat(a) flat(a)^T], - E[flat(ds) flat(ds)^T]) + $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], + E[flat(ds) flat(ds)^T])$$ where - ds = (d / ds) log p(y | x, w) + $$ds = (d / ds) log p(y | x, w)$$ #locations = number of (x, y) locations where 'w' is applied. where the expectation is taken over all examples and locations and flat() @@ -858,10 +861,10 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._package_minibatches(grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor( @@ -1137,42 +1140,327 @@ def num_conv_locations(input_shape, strides): return spatial_input_locations // spatial_strides_divisor -class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB): +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" + + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb """ - def __init__(self, layer_collection, has_bias=False): + def __init__(self, layer_collection, has_bias=False, num_uses=None): """Creates a FullyConnectedMultiIndepFB block. Args: layer_collection: LayerCollection instance. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) """ self._has_bias = has_bias - super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) def instantiate_factors(self, grads_list, damping): - - self._num_uses = float(len(self._inputs[0])) - inputs, grads_list = self._package_minibatches_multi(grads_list) + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._has_bias)) + ((inputs,), self._num_uses, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) self._setup_damping(damping, normalization=self._num_uses) @property def _renorm_coeff(self): - return self._num_uses + return float(self._num_uses) - def tensors_to_compute_grads(self): - return self._outputs + +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses + + +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) class SeriesFBApproximation(enum.IntEnum): @@ -1181,10 +1469,12 @@ class SeriesFBApproximation(enum.IntEnum): option2 = 2 -class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters across time. - See the following preprint for details: + This class implements the "Option 1" and "Option 2" approximation from the + following paper: https://openreview.net/pdf?id=HyMTkQZAb See the end of the appendix of the paper for a pseudo-code of the @@ -1196,6 +1486,7 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): def __init__(self, layer_collection, has_bias=False, + num_uses=None, option=SeriesFBApproximation.option2): """Constructs a new `FullyConnectedSeriesFB`. @@ -1203,6 +1494,10 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. has_bias: Whether the layer includes a bias parameter. + num_uses: int or None. Number of time-steps over which the layer + is used. Only required if the data is formatted with time folded into + the batch dimension (instead of time being a list dimension). + (Default: None) option: A `SeriesFBApproximation` specifying the simplifying assumption to be used in this block. `option1` approximates the cross-covariance over time as a symmetric matrix, while `option2` makes @@ -1213,36 +1508,33 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): self._has_bias = has_bias self._option = option - super(FullyConnectedSeriesFB, self).__init__(layer_collection) + super(FullyConnectedSeriesFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) - def instantiate_factors(self, grads_list, damping): + @property + def _num_timesteps(self): + return self._num_uses + + @property + def _renorm_coeff(self): + # This should no longer be used since the multiply_X functions from the base + # class have been overridden + assert False - self._num_timesteps = len(self._inputs[0]) - inputs, grads_list = self._package_minibatches_multi(grads_list) + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, ((inputs,), self._has_bias)) + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) self._input_factor.register_cov_dt1() self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) self._output_factor.register_cov_dt1() - def compute_damping(): - normalized_damping = normalize_damping(damping, self._num_timesteps) - return compute_pi_adjusted_damping(self._input_factor.get_cov(), - self._output_factor.get_cov(), - normalized_damping**0.5) - - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - "normalize_damping", - damping, self._num_timesteps, "power", 0.5) - self._input_damping_func = _package_func(lambda: compute_damping()[0], - damping_id + ("ref", 0)) - self._output_damping_func = _package_func(lambda: compute_damping()[1], - damping_id + ("ref", 1)) + self._setup_damping(damping, normalization=self._num_uses) def register_matpower(self, exp): if exp != -1: @@ -1275,7 +1567,7 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): if self._option == SeriesFBApproximation.option1: - # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) L_A, psi_A = self._input_factor.get_option1quants( self._input_damping_func) L_G, psi_G = self._output_factor.get_option1quants( @@ -1289,33 +1581,33 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): T = self._num_timesteps return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) - # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) # Even though Y is Z-independent we are recomputing it from the psi's # each since Y depends on both A and G quantities, and it is relatively # cheap to compute. Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) - # Z = L_G^T * Z * L_A + # \\(Z = L_G^T * Z * L_A\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = U_G^T * Z * U_A + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = U_G^T * Z * U_A\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) - # Z = Z .* Y + # \\(Z = Z .* Y\\) Z *= Y - # Z = L_G * Z * L_A^T + # \\(Z = L_G * Z * L_A^T\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = U_G * Z * U_A^T - # Z = G0^(-1/2) * Z * A0^(-1/2) + # \\(Z = U_G * Z * U_A^T\\) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) elif self._option == SeriesFBApproximation.option2: - # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), - # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), + # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) P_A, K_A, mu_A = self._input_factor.get_option2quants( self._input_damping_func) P_G, K_G, mu_G = self._output_factor.get_option2quants( @@ -1324,26 +1616,26 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): # Our approach differs superficially from the pseudo-code in the paper # in order to reduce the total number of matrix-matrix multiplies. # In particular, the first three computations in the pseudo code are - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = Z - hPsi_G^T * Z * hPsi_A - # Z = E_G^T * Z * E_A - # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that - # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) + # \\(Z = E_G^T * Z * E_A\\) + # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that + # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) # the entire computation can be written as - # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A - # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A - # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A - # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A - # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) + # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) + # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) + # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) + # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) # This final expression is computed by the following two lines: - # Z = Z - P_G * Z * P_A^T + # \\(Z = Z - P_G * Z * P_A^T\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) - # Z = K_G^T * Z * K_A + # \\(Z = K_G^T * Z * K_A\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) - # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) # Be careful with the outer product. We don't want to accidentally # make it an inner-product instead. tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A @@ -1354,13 +1646,13 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): # We now perform the transpose/reverse version of the operations # derived above, whose derivation from the original pseudo-code is # analgous. - # Z = K_G * Z * K_A^T + # \\(Z = K_G * Z * K_A^T\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) - # Z = Z - P_G^T * Z * P_A + # \\(Z = Z - P_G^T * Z * P_A\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) - # Z = normalize (1/E[T]) * Z + # \\(Z = normalize (1/E[T]) * Z\\) # Note that this normalization is done because we compute the statistics # by averaging, not summing, over time. (And the gradient is presumably # summed over time, not averaged, and thus their scales are different.) @@ -1372,6 +1664,3 @@ class FullyConnectedSeriesFB(InputOutputMultiMinibatch, FisherBlock): return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name - - def tensors_to_compute_grads(self): - return self._outputs diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 6fc163e2323666aca8489bf146ebc8582995cf06..0d40d265a1727075d0ba721b0d9a756c38269a96 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import contextlib import numpy as np import six @@ -37,6 +38,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.util import nest + # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). INIT_COVARIANCES_AT_ZERO = False @@ -53,16 +55,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. +TOWER_STRATEGY = "concat" + def set_global_constants(init_covariances_at_zero=None, zero_debias=None, eigenvalue_decomposition_threshold=None, - eigenvalue_clipping_threshold=None): + eigenvalue_clipping_threshold=None, + tower_strategy=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD + global TOWER_STRATEGY if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero @@ -72,6 +83,8 @@ def set_global_constants(init_covariances_at_zero=None, EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument @@ -90,6 +103,15 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield + + def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. @@ -256,6 +278,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _num_towers(self): + pass + @abc.abstractproperty def _dtype(self): """dtype for variable backing this factor.""" @@ -278,12 +304,14 @@ class FisherFactor(object): dtype=self._dtype) @abc.abstractmethod - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): """Computes minibatch-estimated covariance for a single source. Args: - idx: int in [0, self._num_sources). Which source to use when estimating - covariance. + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. Returns: Tensor of same shape as self.get_cov_var(). @@ -298,15 +326,33 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov_contribs = tuple(self._compute_new_cov(idx) - for idx in range(self._num_sources)) - new_cov = math_ops.add_n(new_cov_contribs) - # Synchronize value across all TPU cores. + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. if utils.on_tpu(): new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod + def _get_data_device(self, tower): + pass + @abc.abstractmethod def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" @@ -597,17 +643,26 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): + assert tower == 0 + # This will be a very basic rank 1 estimate - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) return ((params_grads_flat * array_ops.transpose( params_grads_flat)) / math_ops.cast(self._batch_size, params_grads_flat.dtype)) + def _get_data_device(self, tower): + return None + class DiagonalFactor(FisherFactor): """A base class for FisherFactors that use diagonal approximations. @@ -692,15 +747,24 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) return (math_ops.square(params_grads_flat) / math_ops.cast( self._batch_size, params_grads_flat.dtype)) + def _get_data_device(self, tower): + return None + class EmbeddingInputKroneckerFactor(DiagonalFactor): r"""FisherFactor for input to an embedding layer. @@ -720,8 +784,8 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): """Instantiate EmbeddingInputKroneckerFactor. Args: - input_ids: Tensor of shape [batch_size, input_size] and dtype int32. - Indices into embedding matrix. + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. dtype: dtype for covariance statistics. Must be a floating point type. Defaults to float32. @@ -744,15 +808,18 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): def _num_sources(self): return 1 + @property + def _num_towers(self): + return len(self._input_ids) + @property def _dtype(self): return self._cov_dtype - def _compute_new_cov(self, idx=0): - if idx != 0: - raise ValueError("EmbeddingInputKroneckerFactor only supports idx = 0") + def _compute_new_cov(self, source, tower): + assert source == 0 - input_ids = self._input_ids + input_ids = self._input_ids[tower] if len(input_ids.shape) > 2: raise ValueError( @@ -782,6 +849,9 @@ class EmbeddingInputKroneckerFactor(DiagonalFactor): return new_cov + def _get_data_device(self, tower): + return self._input_ids[tower].device + class FullyConnectedDiagonalFactor(DiagonalFactor): r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. @@ -801,10 +871,11 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): """Instantiate FullyConnectedDiagonalFactor. Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to this layer. + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. outputs_grads: List of Tensors, each of shape [batch_size, output_size], which are the gradients of the loss with respect to the layer's - outputs. One Tensor for each "source". + outputs. First index is source, second is tower. has_bias: bool. If True, append '1' to each input. """ @@ -818,47 +889,58 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): @property def _var_scope(self): return "ff_diagfc_" + scope_string_from_params( - (self._inputs,) + tuple(self._outputs_grads)) + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): - input_size = self._inputs.shape[1] + self._has_bias - output_size = self._outputs_grads[0].shape[1] + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] return [input_size, output_size] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._outputs_grads[0][0].dtype def make_covariance_update_op(self, ema_decay): - inputs = self._inputs - if self._has_bias: - inputs = append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( ema_decay) - def _compute_new_cov(self, idx=0): - batch_size = array_ops.shape(self._squared_inputs)[0] - outputs_grad = self._outputs_grads[idx] + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. new_cov = math_ops.matmul( - self._squared_inputs, + self._squared_inputs[tower], math_ops.square(outputs_grad), transpose_a=True) new_cov /= math_ops.cast(batch_size, new_cov.dtype) return new_cov + def _get_data_device(self, tower): + return self._inputs[tower].device + class ConvDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" @@ -875,11 +957,12 @@ class ConvDiagonalFactor(DiagonalFactor): """Creates a ConvDiagonalFactor object. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. - Input activations to this layer. + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. outputs_grads: List of Tensors, each of shape [batch_size, height, width, out_channels], which are the gradients of the loss - with respect to the layer's outputs. One Tensor for each "source". + with respect to the layer's outputs. First index is source, second + index is tower. filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). @@ -897,14 +980,15 @@ class ConvDiagonalFactor(DiagonalFactor): """ if not utils.is_data_format_channel_last(data_format): raise ValueError("Channel must be last.") - if inputs.shape.ndims != 4: - raise ValueError("inputs must be 4-D Tensor.") - if inputs.shape.as_list()[-1] != filter_shape[-2]: + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): raise ValueError("inputs and filter_shape must agree on in_channels.") for i, outputs_grad in enumerate(outputs_grads): - if outputs_grad.shape.ndims != 4: + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): raise ValueError("outputs[%d] must be 4-D Tensor." % i) - if outputs_grad.shape.as_list()[-1] != filter_shape[-1]: + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): raise ValueError( "outputs[%d] and filter_shape must agree on out_channels." % i) if len(strides) != 4: @@ -927,7 +1011,7 @@ class ConvDiagonalFactor(DiagonalFactor): @property def _var_scope(self): return "ff_convdiag_" + scope_string_from_params( - (self._inputs,) + tuple(self._outputs_grads)) + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): @@ -941,9 +1025,13 @@ class ConvDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._inputs[0].dtype def make_covariance_update_op(self, ema_decay): filter_height, filter_width, _, _ = self._filter_shape @@ -954,25 +1042,30 @@ class ConvDiagonalFactor(DiagonalFactor): rates = (1, 1, 1, 1) else: rates = tuple(self._dilations) - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=rates, - padding=self._padding) - if self._has_bias: - patches = append_homog(patches) + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) - self._patches = patches + self._patches.append(patches) return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) - def _compute_new_cov(self, idx=0): - batch_size = array_ops.shape(self._patches)[0] - outputs_grad = self._outputs_grads[idx] + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) new_cov /= math_ops.cast(batch_size, new_cov.dtype) return new_cov @@ -985,6 +1078,9 @@ class ConvDiagonalFactor(DiagonalFactor): outputs_grad) return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + def _get_data_device(self, tower): + return self._inputs[tower].device + class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. @@ -996,9 +1092,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Instantiate FullyConnectedKroneckerFactor. Args: - tensors: List of Tensors, each of shape [batch_size, n], one for each - source. The Tensors are typically either a layer's inputs or its - output's gradients. + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of @@ -1010,27 +1106,34 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): @property def _var_scope(self): return "ff_fckron_" + scope_string_from_params( - tuple(self._tensors) + (self._has_bias,)) + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): - size = self._tensors[0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) + @property + def _num_towers(self): + return len(self._tensors[0]) + @property def _dtype(self): - return self._tensors[0].dtype + return self._tensors[0][0].dtype - def _compute_new_cov(self, idx=0): - tensor = self._tensors[idx] + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] if self._has_bias: tensor = append_homog(tensor) return compute_cov(tensor) + def _get_data_device(self, tower): + return self._tensors[0][tower].device + class ConvInputKroneckerFactor(InverseProvidingFactor): r"""Kronecker factor for the input side of a convolutional layer. @@ -1054,8 +1157,8 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): """Initializes ConvInputKroneckerFactor. Args: - inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels]. - Inputs to layer. + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. filter_shape: List of ints. Contains [..spatial_filter_size.., in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". @@ -1084,10 +1187,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_convinkron_" + scope_string_from_params([ - self._inputs, self._filter_shape, self._strides, self._padding, - self._dilation_rate, self._data_format, self._has_bias - ]) + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) @property def _cov_shape(self): @@ -1100,19 +1203,24 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._inputs.dtype + return self._inputs[0].dtype - def _compute_new_cov(self, idx=0): - if idx != 0: - raise ValueError("ConvInputKroneckerFactor only supports idx = 0") + def _compute_new_cov(self, source, tower): + assert source == 0 + + inputs = self._inputs[tower] # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. if self._extract_patches_fn in [None, "extract_convolution_patches"]: patches = utils.extract_convolution_patches( - self._inputs, + inputs, self._filter_shape, padding=self._padding, strides=self._strides, @@ -1120,7 +1228,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): data_format=self._data_format) elif self._extract_patches_fn == "extract_image_patches": - assert self._inputs.shape.ndims == 4 + assert inputs.shape.ndims == 4 assert len(self._filter_shape) == 4 assert len(self._strides) == 4, self._strides if self._dilation_rate is None: @@ -1130,7 +1238,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): assert len(rates) == 4 assert rates[0] == rates[-1] == 1 patches = array_ops.extract_image_patches( - self._inputs, + inputs, ksizes=[1] + list(self._filter_shape[0:-2]) + [1], strides=self._strides, rates=rates, @@ -1140,7 +1248,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] assert self._filter_shape[0] == self._filter_shape[1] == 1 patches = utils.extract_pointwise_conv2d_patches( - self._inputs, self._filter_shape, data_format=None) + inputs, self._filter_shape, data_format=None) else: raise NotImplementedError(self._extract_patches_fn) @@ -1165,6 +1273,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): # (Tilde omitted over A for clarity.) return compute_cov(patches_flat) + def _get_data_device(self, tower): + return self._inputs[tower].device + class ConvOutputKroneckerFactor(InverseProvidingFactor): r"""Kronecker factor for the output side of a convolutional layer. @@ -1181,9 +1292,9 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): """Initializes ConvOutputKroneckerFactor. Args: - outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, ..spatial_input_size.., out_channels]. One Tensor per - source. + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. data_format: None or str. Format of outputs_grads. Raises: @@ -1191,13 +1302,14 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): """ if not utils.is_data_format_channel_last(data_format): raise ValueError("Channel must be last.") - self._out_channels = outputs_grads[0].shape.as_list()[-1] + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convoutkron_" + scope_string_from_params(self._outputs_grads) + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) @property def _cov_shape(self): @@ -1208,12 +1320,16 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._outputs_grads[0][0].dtype - def _compute_new_cov(self, idx=0): - outputs_grad = self._outputs_grads[idx] + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] # reshaped_tensor below is the matrix DS_l defined in the KFC paper # (tilde omitted over S for clarity). It has shape M|T| x I, where @@ -1226,28 +1342,30 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): # (Tilde omitted over S for clarity.) return compute_cov(reshaped_tensor) + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device -class FullyConnectedMultiKF(InverseProvidingFactor): + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, - tensor_lists, + tensors, + num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: - tensor_lists: 2D array (list of lists) of Tensors of shape - [batch_size, n]. Each of these tensors is usually a layer's inputs or - its output's gradients. The first dimension of the array is the source, - and the second is the use in the graph (which is sometimes a - "time-step"). + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. """ - self._tensor_lists = tensor_lists - self._has_bias = has_bias - self._num_timesteps = len(tensor_lists[0]) - self._tensors = [None] * len(tensor_lists) + self._num_uses = num_uses self._cov_dt1 = None self._make_cov_dt1 = False @@ -1256,29 +1374,38 @@ class FullyConnectedMultiKF(InverseProvidingFactor): self._option1quants_registrations = set() self._option2quants_registrations = set() - super(FullyConnectedMultiKF, self).__init__() - - @property - def _var_scope(self): - return "ff_fc_multi_" + scope_string_from_params( - tuple(nest.flatten(self._tensor_lists)) + (self._has_bias,)) + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) @property - def _num_sources(self): - return len(self._tensor_lists) + def _num_timesteps(self): + return self._num_uses @property - def _dtype(self): - return self._tensor_lists[0][0].dtype + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) def make_covariance_update_op(self, ema_decay): op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) if self._cov_dt1 is not None: - new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) - for idx in range(self._num_sources)) - new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) + + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) + + # See comments in FisherFactor.make_covariance_update_op() for details. + if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) + op2 = moving_averages.assign_moving_average( self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) @@ -1291,36 +1418,31 @@ class FullyConnectedMultiKF(InverseProvidingFactor): return op - def _compute_new_cov(self, idx=0): - # Concatenate across time/replications - tensor = array_ops.concat(self._tensor_lists[idx], 0) + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) tensor = append_homog(tensor) - # We save these so they can be used by _compute_new_cov_dt1 - self._tensors[idx] = tensor - return compute_cov(tensor) - def _compute_new_cov_dt1(self, idx=0): # pylint: disable=missing-docstring - tensor = self._tensors[idx] - batch_size = array_ops.shape(self._tensor_lists[idx][0])[0] - # Is there a more elegant way to do this computation? + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps + tensor_present = tensor[:-batch_size, :] tensor_future = tensor[batch_size:, :] + # We specify a normalizer for this computation to ensure a PSD Fisher # block estimate. This is equivalent to padding with zeros, as was done # in Section B.2 of the appendix. - normalizer = self._num_timesteps * batch_size return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=normalizer) + tensor_future, tensor_right=tensor_present, normalizer=total_len) - @property - def _cov_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias - return [size, size] + def _get_data_device(self, tower): + return self._tensors[0][tower].device @property def _vec_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size] def get_option1quants(self, damping_func): diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 4eb5e4c092b50ff4a908a22312330c40ca93cbee..19608aca4716a08ec9f9bea35d07de3a434bbe3f 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -60,6 +60,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + APPROX_KRONECKER_INDEP_NAME = "kron_indep" APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" @@ -72,6 +76,14 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { option=2) } +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" @@ -169,9 +181,12 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_SERIES_2_NAME) + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME self.loss_colocation_ops = {} self._vars_to_uses = defaultdict(lambda: 0) @@ -245,14 +260,14 @@ class LayerCollection(object): @property def default_conv2d_approximation(self): - return self._default_convolution_2d_approximation + return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) - self._default_convolution_2d_approximation = value + self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): @@ -264,6 +279,14 @@ class LayerCollection(object): "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -367,7 +390,7 @@ class LayerCollection(object): if name in self._loss_dict: raise KeyError( "Loss function named {} already exists. Set reuse=True to append " - "another minibatch/tower.".format(name)) + "another tower.".format(name)) loss_list = [] self._loss_dict[name] = loss_list @@ -526,45 +549,54 @@ class LayerCollection(object): else: return None + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + def register_embedding(self, params, inputs, outputs, approx=None, reuse=VARIABLE_SCOPE): - """Registers a fully connnected layer. + """Registers an embedding layer. Args: params: Embedding matrix of shape [vocab_size, embedding_size]. inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. - outputs: Tensor of shape [batch_size, output_size]. Outputs + outputs: Tensor of shape [batch_size, embedding_size]. Outputs produced by layer. - approx: str. Must be "kron". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_embedding_approximation - - if approx != APPROX_KRONECKER_NAME: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) if isinstance(params, (tuple, list)): raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) block = self.register_block( - params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) @@ -583,30 +615,29 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_approximation - if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) - block = self.register_block(params, block_type(self, has_bias=has_bias), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) @@ -635,10 +666,14 @@ class LayerCollection(object): Output produced by layer. data_format: str or None. Format of data. dilations: List of 4 ints. Dilations along each dimension. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. @@ -646,15 +681,14 @@ class LayerCollection(object): ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_conv2d_approximation + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) - if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - - block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can't use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? if approx == APPROX_KRONECKER_NAME: block = self.register_block( params, @@ -680,9 +714,9 @@ class LayerCollection(object): data_format=data_format), reuse=reuse) else: - raise NotImplementedError + raise NotImplementedError(approx) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) @@ -712,16 +746,22 @@ class LayerCollection(object): dilation_rate: List of ints of length len(..input_spatial_size..). Dilations along spatial dimension. data_format: str or None. Format of data. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? assert approx is None or approx == APPROX_KRONECKER_NAME block = self.register_block( @@ -734,7 +774,7 @@ class LayerCollection(object): dilation_rate=dilation_rate, data_format=data_format), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) @@ -762,16 +802,21 @@ class LayerCollection(object): rate: None or List of ints of length 2. Dilation rates in spatial dimensions. data_format: str or None. Format of data. - approx: None or str. Must be "diagonal" if non-None. - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? assert approx is None or approx == APPROX_DIAGONAL_NAME assert data_format in [None, "NHWC"] @@ -785,7 +830,7 @@ class LayerCollection(object): rate=rate, data_format=data_format), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self._add_uses(params, 1) @@ -803,7 +848,7 @@ class LayerCollection(object): reuse=VARIABLE_SCOPE): """Register a call to tf.nn.separable_conv2d(). - Note: This requires access to intermediate outputs betwee depthwise and + Note: This requires access to intermediate outputs between depthwise and pointwise convolutions. Args: @@ -824,10 +869,14 @@ class LayerCollection(object): rate: None or List of ints of length 2. Dilation rate of depthwise conv2d kernel in spatial dimensions. data_format: str or None. Format of data. - approx: None or str. Must be "kron" if non-None. - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. @@ -864,34 +913,32 @@ class LayerCollection(object): Args: params: Tensor or tuple of Tensors corresponding to the parameters. - batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of "full" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds 'batch_size' to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_generic_approximation - - if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - - block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) - block.register_additional_minibatch(batch_size) + block.register_additional_tower(batch_size) self._add_uses(params, float("inf")) def register_fully_connected_multi(self, params, inputs, outputs, - approx=None, reuse=VARIABLE_SCOPE): + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): """Register fully connected layers with shared parameters. This can handle general fully-connected layers with shared parameters, but @@ -902,41 +949,195 @@ class LayerCollection(object): params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. - inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs - to layer. In the case of RNNs, one Tensor per time step. - outputs: A list of tensors, the same length as 'inputs', each of shape - [batch_size, output_size]. Outputs produced by layer. In the case of - RNNs, one Tensor per time step. - approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in 'inputs'. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") Raises: ValueError: For improper value to 'approx'. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_multi_approximation - has_bias = isinstance(params, (tuple, list)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) # TODO(b/70283649): something along the lines of find_canonical_output # should be added back in here (and for the other block types, arguably). - if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] - - block = self.register_block(params, block_type(self, has_bias=has_bias), + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) - self._add_uses(params, len(inputs)) + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in 'inputs'. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) # TODO(b/74108452): change the loss registration functions names to refer # to "loss functions" instead of distributions. Following naming convention # of the loss function classes themselves. + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in 'inputs'. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds inputs and outputs as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word 'use' here has a completely different meaning to "use in the graph" + as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + def register_categorical_predictive_distribution(self, logits, seed=None, @@ -955,9 +1156,10 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. + reuse: bool or str. If True, this adds 'logits' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, seed=seed) @@ -988,9 +1190,10 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. + reuse: bool or str. If True, this adds 'mean' and 'var' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, seed=seed) @@ -1016,9 +1219,10 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. + reuse: bool or str. If True, this adds 'logits' as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, seed=seed) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 083da768ec97aca3e63995491bb579835bb5377f..843aeef7d82df064b757ab4618f2b0ccbbec4cbe 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import warnings - # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -53,8 +52,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): estimation_mode="gradients", colocate_gradients_with_ops=True, batch_size=None, - cov_devices=None, - inv_devices=None): + placement_strategy=None, + **kwargs): """Initializes the KFAC optimizer with the given settings. Args: @@ -96,14 +95,11 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): (Default: True) batch_size: The size of the mini-batch. Only needed when momentum_type == 'qmodel' or when automatic adjustment is used. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. Only used - with (soon-to-be-depcrecated "convenience" properties). - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. Only used - with (soon-to-be-depcrecated "convenience" properties). + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passesd to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. @@ -123,8 +119,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._layers = layer_collection self._estimation_mode = estimation_mode self._colocate_gradients_with_ops = colocate_gradients_with_ops - self._cov_devices = cov_devices - self._inv_devices = inv_devices # The below paramaters are required only if damping needs to be adapated. # These parameters can be set by calling @@ -164,16 +158,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._momentum_type = momentum_type self._norm_constraint = norm_constraint self._batch_size = batch_size + self._placement_strategy = placement_strategy with variable_scope.variable_scope(name): - self._fisher_est = est.FisherEstimator( - self._variables, - self._cov_ema_decay, - self.damping, - self._layers, + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, exps=(-1,), estimation_mode=self._estimation_mode, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) @@ -236,6 +233,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._damping = variable_scope.get_variable( "damping", initializer=self._damping_constant, trainable=False) + @property + def variables(self): + return self._variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + @property def cov_update_thunks(self): self._maybe_make_and_save_everything() @@ -266,37 +278,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._maybe_make_and_save_everything() return self._inv_update_op - @property - def variables(self): - return self._variables - - @property - def damping(self): - if self._damping: - return self._damping - else: - return self._damping_constant - - @property - def damping_adaptation_interval(self): - return self._damping_adaptation_interval - def _maybe_make_and_save_everything(self): if not self._fisher_est.made_vars(): warnings.warn("These convenience properties will be depcrecated soon. " "Please use explicit op/thunk creation methods instead " - "(e.g. make_ops_and_vars_round_robin, etc).", + "(e.g. make_ops_and_vars, etc).", DeprecationWarning) (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, self._inv_update_op, self._cov_update_thunks, - self._inv_update_thunks) = self.make_ops_and_vars_round_robin( - cov_devices=self._cov_devices, - inv_devices=self._inv_devices) + self._inv_update_thunks) = self.make_ops_and_vars() def make_ops_and_vars(self): - """Make ops and vars with no specific device placement. + """Make ops and vars with device placement `self._placement_strategy`. - See make_ops_and_vars_round_robin for details. + See `FisherEstimator.make_ops_and_vars` for details. Returns: cov_update_ops: List of ops that compute the cov updates. Corresponds @@ -307,77 +302,11 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): cov_update_op: cov_update_ops grouped into a single op. inv_update_op: inv_update_ops grouped into a single op. """ - with variable_scope.variable_scope(self.get_name()): - return self._fisher_est.make_ops_and_vars() - - def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None): - """Make ops and vars with a round-robin device placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. + return self._fisher_est.make_ops_and_vars(scope=self.get_name()) - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - with variable_scope.variable_scope(self.get_name()): - return self._fisher_est.make_ops_and_vars_round_robin( - cov_devices=cov_devices, inv_devices=inv_devices) - - def make_vars_and_create_op_thunks_round_robin(self, - cov_devices=None, - inv_devices=None): - """Make vars and create op thunks w/ a round-robin device placement strat. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. Returns: cov_update_thunks: List of cov update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. @@ -385,10 +314,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): the list of factors given by the "factors" property. """ scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.make_vars_and_create_op_thunks_round_robin( - scope=scope, cov_devices=cov_devices, inv_devices=inv_devices) + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) - def ops_and_vars_thunks(self): + def create_ops_and_vars_thunks(self): """Create thunks that make the ops and vars on demand. This function returns 4 lists of thunks: cov_variable_thunks, @@ -413,7 +341,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): inv_update_thunks: A list of thunks that make the inv update ops. """ scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.ops_and_vars_thunks(scope=scope) + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def minimize(self, *args, **kwargs): # Should this variable scope encompass everything below? Or will the super- @@ -462,7 +390,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): An `Operation` that applies the specified gradients. """ self._maybe_make_and_save_everything() - # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. @@ -618,7 +545,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # compute the matrix-vector products with the transposed Fisher factor fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( self._batch_size, dtype=fft_precon_grads[0].dtype) @@ -802,7 +728,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] - # TODO(b/73448937): Move all update damping code to a separate class/function. def _update_damping(self, prev_batch, global_step): """Adapts damping parameter. Check KFAC (Section 6.5) for the details. diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000000000000000000000000000000000..bf12dbaa9adbaa4af1511034aef0b5ab59d53e26 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + *args: + **kwargs: + + """ + super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a round-robin device placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `None` then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + (cov_update_thunks, + inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) + cov_update_ops = [thunk() for thunk in cov_update_thunks] + inv_update_ops = [thunk() for thunk in inv_update_thunks] + + scope = self.name if scope is None else scope + with variable_scope.variable_scope(scope): + cov_update_op = control_flow_ops.group(cov_update_ops, + name="cov_update_op") + inv_update_op = control_flow_ops.group(inv_update_ops, + name="inv_update_op") + + return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, + cov_update_thunks, inv_update_thunks) + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement strat. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index af26f5e56bf9bb22cc9bc2b409209d027477ed89..b6f42815e79fa5eb9c6a2aa9f99ac3ec5a70ad0a 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -649,9 +649,6 @@ class PartitionedTensor(object): def dtype(self): return self.tensors[0].dtype - def devices(self): - return set(tensor.device for tensor in self.tensors) - def __str__(self): return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) @@ -659,6 +656,17 @@ class PartitionedTensor(object): def __hash__(self): return hash(tuple(self.tensors)) + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + def as_tensor(self, dtype=None, name=None, as_ref=False): with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): assert not as_ref @@ -670,6 +678,15 @@ class PartitionedTensor(object): self._concats[result.device] = result return self._concats[result.device] + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + ops.register_tensor_conversion_function( PartitionedTensor, diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 544065dac6a10094a376c18e84521b1a26401cdd..c8812d4b23f94102d093db878a709b090a3318d6 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -214,14 +214,3 @@ py_test( "//tensorflow/python:math_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index cc7bbabf210ded9a31eb789fa8b94e8bde62ea43..d5b3b279a1b7327602790c0260349cb0c758aa86 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -392,15 +392,3 @@ py_test( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 337c9e06b870b2cca53fcdbf3d94225660e193c4..00f03a111ae8be7f49761ef5fb5a82810bcca182 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -104,6 +104,7 @@ See the @{$python/contrib.layers} guide. @@infer_real_valued_columns @@sequence_input_from_feature_columns +@@group_norm @@instance_norm """ @@ -122,6 +123,7 @@ _allowed_symbols = ['bias_add', 'conv3d', 'elu', 'feature_column', + 'group_norm', 'instance_norm', 'legacy_fully_connected', 'legacy_linear', diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD index e407a9ce015603094c7bbab72856403e2f0eb1a1..7aae09ff3e9995b2d92b05211b3bf8a94a26ff43 100644 --- a/tensorflow/contrib/layers/kernels/BUILD +++ b/tensorflow/contrib/layers/kernels/BUILD @@ -18,14 +18,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 350bcb3bca11b4cad18ce863ab1496076477aa3c..10d7f6d076b4b4c6578d7adcffc4e9cc44d77ac6 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -3045,16 +3045,16 @@ def legacy_fully_connected(x, `activation_fn` is `None`, the result of `y = w * x + b` is returned. - If `x` has shape [\\\(\\text{dim}_0, \\text{dim}_1, ..., \\text{dim}_n\\\)] - with more than 2 dimensions (\\\(n > 1\\\)), then we repeat the matrix + If `x` has shape [\\(\text{dim}_0, \text{dim}_1, ..., \text{dim}_n\\)] + with more than 2 dimensions (\\(n > 1\\)), then we repeat the matrix multiply along the first dimensions. The result r is a tensor of shape - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`], - where \\\( r_{i_0, ..., i_{n-1}, k} = - \\sum_{0 \\leq j < \\text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\\). + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`], + where \\( r_{i_0, ..., i_{n-1}, k} = + \sum_{0 \leq j < \text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\). This is accomplished by reshaping `x` to 2-D - [\\\(\\text{dim}_0 \\cdot ... \\cdot \\text{dim}_{n-1}, \\text{dim}_n\\\)] + [\\(\text{dim}_0 \cdot ... \cdot \text{dim}_{n-1}, \text{dim}_n\\)] before the matrix multiply and afterwards reshaping it to - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`]. + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`]. This op creates `w` and optionally `b`. Bias (`b`) can be disabled by setting `bias_init` to `None`. diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index e7d4080ff769327cc74b6629a7705ddfa552169b..c807ab0f2e5c8ac3ec2ae1d84a5b36b5f4ba76a4 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope __all__ = [ + 'group_norm', 'instance_norm', ] @@ -158,3 +160,196 @@ def instance_norm(inputs, if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + +@add_arg_scope +def group_norm(inputs, + groups=32, + channels_axis=-1, + reduction_axes=(-3, -2), + center=True, + scale=True, + epsilon=1e-6, + activation_fn=None, + param_initializers=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Functional interface for the group normalization layer. + + Reference: https://arxiv.org/abs/1803.08494. + + "Group Normalization", Yuxin Wu, Kaiming He + + Args: + inputs: A Tensor with at least 2 dimensions one which is channels. All + shape dimensions must be fully defined. + groups: Integer. Divide the channels into this number of groups over which + normalization statistics are computed. This number must be commensurate + with the number of channels in `inputs`. + channels_axis: An integer. Specifies index of channels axis which will be + broken into `groups`, each of which whose statistics will be computed + across. Must be mutually exclusive with `reduction_axes`. Preferred usage + is to specify negative integers to be agnostic as to whether a batch + dimension is included. + reduction_axes: Tuple of integers. Specifies dimensions over which + statistics will be accumulated. Must be mutually exclusive with + `channels_axis`. Statistics will not be accumulated across axes not + specified in `reduction_axes` nor `channel_axis`. Preferred usage is to + specify negative integers to be agnostic to whether a batch dimension is + included. + + Some sample usage cases: + NHWC format: channels_axis=-1, reduction_axes=[-3, -2] + NCHW format: channels_axis=-3, reduction_axes=[-2, -1] + + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + epsilon: Small float added to variance to avoid dividing by zero. + activation_fn: Activation function, default set to None to skip it and + maintain a linear activation. + param_initializers: Optional initializers for beta, gamma, moving mean and + moving variance. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional collections for the variables. + outputs_collections: Collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + scope: Optional scope for `variable_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: If the rank of `inputs` is undefined. + ValueError: If rank or channels dimension of `inputs` is undefined. + ValueError: If number of groups is not commensurate with number of channels. + ValueError: If reduction_axes or channels_axis are out of bounds. + ValueError: If reduction_axes are not mutually exclusive with channels_axis. + """ + # TODO(shlens): Support partially defined shapes for the inputs. + inputs = ops.convert_to_tensor(inputs) + original_shape = inputs.shape + + if inputs.shape.ndims is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + if channels_axis > (inputs.shape.ndims - 1): + raise ValueError('Axis is out of bounds.') + + # Standardize the channels_axis to be positive and identify # of channels. + if channels_axis < 0: + channels_axis = inputs.shape.ndims + channels_axis + channels = inputs.shape[channels_axis].value + + if channels is None: + raise ValueError('Inputs %s has undefined channel dimension: %d.' % ( + inputs.name, channels_axis)) + + # Standardize the reduction_axes to be positive. + reduction_axes = list(reduction_axes) + for i in range(len(reduction_axes)): + if reduction_axes[i] < 0: + reduction_axes[i] += inputs.shape.ndims + + for a in reduction_axes: + if a > inputs.shape.ndims: + raise ValueError('Axis is out of bounds.') + if inputs.shape[a].value is None: + raise ValueError('Inputs %s has undefined dimensions %d.' % ( + inputs.name, a)) + if channels_axis == a: + raise ValueError('reduction_axis must be mutually exclusive ' + 'with channels_axis') + if groups > channels: + raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) + if channels % groups != 0: + raise ValueError('%d channels is not commensurate with %d groups.' % + (channels, groups)) + + # Determine axes before channels. Some examples of common image formats: + # 'NCHW': before = [N], after = [HW] + # 'NHWC': before = [NHW], after = [] + axes_before_channels = inputs.shape.as_list()[:channels_axis] + axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + + # Manually broadcast the parameters to conform to the number of groups. + params_shape_broadcast = ([1] * len(axes_before_channels) + + [groups, channels // groups] + + [1] * len(axes_after_channels)) + + # Reshape the input by the group within the channel dimension. + inputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + inputs = array_ops.reshape(inputs, inputs_shape) + + # Determine the dimensions across which moments are calculated. + moments_axes = [channels_axis + 1] + for a in reduction_axes: + if a > channels_axis: + moments_axes.append(a + 1) + else: + moments_axes.append(a) + + with variable_scope.variable_scope( + scope, 'GroupNorm', [inputs], reuse=reuse) as sc: + # Note that the params_shape is the number of channels always. + params_shape = [channels] + + # Allocate parameters for the beta and gamma of the normalization. + beta, gamma = None, None + dtype = inputs.dtype.base_dtype + if param_initializers is None: + param_initializers = {} + if center: + beta_collections = utils.get_variable_collections( + variables_collections, 'beta') + beta_initializer = param_initializers.get( + 'beta', init_ops.zeros_initializer()) + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) + beta = array_ops.reshape(beta, params_shape_broadcast) + + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get( + 'gamma', init_ops.ones_initializer()) + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + gamma = array_ops.reshape(gamma, params_shape_broadcast) + + # Calculate the moments. + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + + # Compute normalization. + # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor + # appropriately so that this operation may be faster. + gain = math_ops.rsqrt(variance + epsilon) + offset = -mean * gain + if gamma is not None: + gain *= gamma + offset *= gamma + if beta is not None: + offset += beta + outputs = inputs * gain + offset + + # Collapse the groups into the channel dimension. + outputs = array_ops.reshape(outputs, original_shape) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index 5cff1bf0ebb2fe8bc6933de882ecd47a9edf0f94..b6e96350db92baf4770683273be7e5dde73dbcec 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase): def testOutputBigInput5DNCHW(self): self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3) + +class GroupNormTest(test.TestCase): + + def testInvalidGroupSize(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10)) + with self.assertRaisesRegexp(ValueError, + 'Invalid groups 10 for 2 channels.'): + normalization.group_norm(inputs, groups=10, + reduction_axes=[-2, -1], channels_axis=-3) + + def testBadCommensurateGroup(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10)) + with self.assertRaisesRegexp(ValueError, + '4 channels is not commensurate with ' + '3 groups.'): + normalization.group_norm(inputs, groups=3, + reduction_axes=[-2, -1], channels_axis=-3) + + def testAxisIsBad(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5)) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, channels_axis=5) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, reduction_axes=[1, 5]) + + def testNotMutuallyExclusiveAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32)) + # Specify axis with negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2]) + # Specify axis with positive values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3]) + # Specify axis with mixed positive and negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2]) + + def testUnknownShape(self): + inputs = array_ops.placeholder(dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'undefined rank'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedReductionAxes(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4)) + with self.assertRaisesRegexp(ValueError, 'undefined dimensions'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedChannelsAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None)) + with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'): + normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2]) + + def testCreateOp(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) + output = normalization.group_norm(images, groups=groups, channels_axis=-1, + reduction_axes=[-3, -2]) + print('name: ', output.op.name) + self.assertListEqual([5, height, width, 2*groups], output.shape.as_list()) + + def testCreateOpFloat64(self): + height, width, groups = 3, 3, 5 + images = random_ops.random_uniform( + (5, height, width, 4*groups), dtype=dtypes.float64, seed=1) + output = normalization.group_norm(images, groups=groups) + self.assertEqual(dtypes.float64, output.dtype) + self.assertListEqual([5, height, width, 4*groups], output.shape.as_list()) + + def testCreateOpNoScaleCenter(self): + height, width, groups = 3, 3, 7 + images = random_ops.random_uniform( + (5, height, width, 3*groups), dtype=dtypes.float32, seed=1) + output = normalization.group_norm(images, groups=groups, center=False, + scale=False) + self.assertListEqual([5, height, width, 3*groups], output.shape.as_list()) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) + + def testCreateVariables_NHWC(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 8), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-1, reduction_axes=(-3, -2), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testCreateVariables_NCHW(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, 2*groups, height, width), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-3, reduction_axes=(-2, -1), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testReuseVariables(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 4), seed=1) + normalization.group_norm(images, groups=2, scale=True, scope='IN') + normalization.group_norm(images, groups=2, scale=True, scope='IN', + reuse=True) + beta = contrib_variables.get_variables_by_name('beta') + gamma = contrib_variables.get_variables_by_name('gamma') + self.assertEqual(1, len(beta)) + self.assertEqual(1, len(gamma)) + + def testValueCorrectWithReuseVars(self): + height, width = 3, 3 + image_shape = (10, height, width, 4) + images = random_ops.random_uniform(image_shape, seed=1) + output_train = normalization.group_norm(images, groups=2, scope='IN') + output_eval = normalization.group_norm(images, groups=2, scope='IN', + reuse=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + # output_train and output_eval should be the same. + train_np, eval_np = sess.run([output_train, output_eval]) + self.assertAllClose(train_np, eval_np) + + def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, + groups=2, tol=1e-2): + # Select the axis for the channel and the dimensions along which statistics + # are accumulated. + if channels_axis < 0: + channels_axis += len(input_shape) + reduced_axes = [channels_axis + 1] + for a in reduction_axes: + if a < 0: + a += len(input_shape) + if a < channels_axis: + reduced_axes.append(a) + else: + reduced_axes.append(a+1) + reduced_axes = tuple(reduced_axes) + + # Calculate the final shape for the output Tensor. + axes_before_channels = input_shape[:channels_axis] + axes_after_channels = input_shape[channels_axis+1:] + channels = input_shape[channels_axis] + outputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + + # Calculate the final shape for the output statistics. + reduced_shape = [] + for i, a in enumerate(outputs_shape): + if i not in reduced_axes: + reduced_shape.append(a) + + for mu in (0.0, 1e2): + for sigma in (1.0, 0.1): + # Determine shape of Tensor after normalization. + expected_mean = np.zeros(reduced_shape) + expected_var = np.ones(reduced_shape) + + inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + output_op = normalization.group_norm( + inputs, groups=groups, center=False, scale=False, + channels_axis=channels_axis, + reduction_axes=reduction_axes) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outputs = sess.run(output_op) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(outputs).any()) + + outputs = np.reshape(outputs, outputs_shape) + mean = np.mean(outputs, axis=reduced_axes) + var = np.var(outputs, axis=reduced_axes) + # The mean and variance of each example should be close to 0 and 1 + # respectively. + self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol) + self.assertAllClose(expected_var, var, rtol=tol, atol=tol) + + def testOutputSmallInput4D_NHWC(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput3D_NHWC(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput4D_NCHW(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputSmallInput3D_NCHW(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputBigInput4D_NHWC(self): + self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], + groups=1) + + def testOutputBigInput4D_NCHW(self): + self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], + groups=4) + + def testOutputSmallInput2D_NC(self): + self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + + def testOutputSmallInput5D_NCXXX(self): + self.doOutputTest([10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 123275e1fde047cd3772528641b2e3b09742fbdc..e49589ddf627aa456496cebb2d0fc72fcdad710f 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -29,14 +29,17 @@ from __future__ import print_function import functools import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -46,6 +49,7 @@ from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") +_USE_DEFAULT = "__rev_block_lib_default" def _acc_grads(*lists_of_grads): @@ -219,7 +223,13 @@ class RevBlock(base.Layer): def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" + # Inputs have passed through an Identity. Recover the original Tensors to + # be able to match up side inputs. + assert [u"Identity"] == list(set([x.op.type for x in inputs])) + inputs = [x.op.inputs[0] for x in inputs] side_inputs = inputs[2:] + del inputs + f_side_idxs = [None] * len(self.f_side_input) g_side_idxs = [None] * len(self.g_side_input) assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) @@ -405,12 +415,36 @@ def rev_block(x1, return block.forward(x1, x2) -def recompute_grad(fn): +def enable_with_args(dec): + """A decorator for decorators to enable their usage with or without args.""" + + @functools.wraps(dec) + def new_dec(*args, **kwargs): + if len(args) == 1 and not kwargs and callable(args[0]): + # Used as decorator without args + fn = args[0] + return dec(fn) + else: + return lambda fn: dec(fn, *args, **kwargs) + + return new_dec + + +@enable_with_args +def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. + use_data_dep: `bool`, if `True` will use a dummy data dependency to force + the recompute to happen. If `False` will use a control dependency. By + default will be `True` if in an XLA context and `False` otherwise. XLA + ignores control dependencies and so this data dependency is necessary. + tupleize_grads: `bool`, if `True` will use control dependencies to ensure + that all gradients are produced before any are consumed by downstream ops. + If `use_data_dep` is also `True`, will use a data dependency instead of + a control dependency. Returns: A wrapped fn that is identical to fn when called, but its activations will @@ -420,13 +454,25 @@ def recompute_grad(fn): @functools.wraps(fn) def wrapped(*args): - return _recompute_grad(fn, args) + return _recompute_grad( + fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads) return wrapped -def _recompute_grad(fn, args): +def _is_on_tpu(): + ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(ctxt) is not None + + +def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + for arg in args: + if not isinstance(arg, framework_ops.Tensor): + raise ValueError("All inputs to function must be Tensors") + use_data_dep_ = use_data_dep + if use_data_dep_ == _USE_DEFAULT: + use_data_dep_ = _is_on_tpu() cached_vs = [] cached_arg_scope = [] @@ -436,6 +482,8 @@ def _recompute_grad(fn, args): del outputs # Recompute outputs with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(cached_arg_scope[0]): with variable_scope.variable_scope(cached_vs[0], reuse=True): outputs = fn(*inputs) @@ -444,6 +492,13 @@ def _recompute_grad(fn, args): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars @@ -532,7 +587,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): get_vars_fn = ( vs.global_variables if use_global_vars else vs.trainable_variables) len_before_vars = len(get_vars_fn()) - inputs = list(inputs) + inputs = [array_ops.identity(x) for x in inputs] outputs = fn(*inputs) train_vars = get_vars_fn()[len_before_vars:] @@ -581,3 +636,48 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): flat_inputs = nest.flatten(defun_inputs) id_out = identity(*flat_inputs) return id_out + + +def _force_data_dependency(first_compute, then_compute): + """Force all of `then_compute` to depend on all of `first_compute`. + + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. + + Args: + first_compute: `list`. These will be made to run before the + `Tensor`s `then_compute`. + then_compute: `list`. These will run after all the `Tensor`s in + `first_compute`. + + Returns: + `list`, same length as `then_compute`. + + Raises: + ValueError: if ranks are unknown or types are not floating. + """ + + def _first_element(x): + if x.get_shape().ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.get_shape().ndims + begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32) + size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32) + return array_ops.reshape(array_ops.slice(x, begin, size), []) + + first_compute_sum = math_ops.add_n( + [_first_element(x) for x in first_compute if x is not None]) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + epsilon = np.finfo(dtype.as_numpy_dtype).tiny + zero = array_ops.stop_gradient(epsilon * first_compute_sum) + + return [ + array_ops.identity(x) + zero if x is not None else None + for x in then_compute + ] + + +def _tuple_with_data_dep(tensors): + return _force_data_dependency(tensors, tensors) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index cbcbcd75114a522b95631e4e7e95c1641b0a9987..d1ad4e8c98de3e5c5ac212d55cc93707ba9c01cc 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase): y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) for g1, g2 in zip(gd_val, g_val): - self.assertAllClose(g1, g2) + self.assertAllClose(g1, g2, rtol=1e-5) def testRevBlock(self): self._testRevBlock() @@ -255,25 +255,54 @@ class RecomputeTest(test.TestCase): def fn_recompute(x): return fn(x) + @rev_block_lib.recompute_grad(use_data_dep=True) + def fn_use_data_dep(x): + return fn(x) + + @rev_block_lib.recompute_grad(tupleize_grads=True) + def fn_tupleize(x): + return fn(x) + + @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True) + def fn_both(x): + return fn(x) + x = random_ops.random_uniform((3, 1, 3)) - recompute_vars = None - with variable_scope.variable_scope("recompute") as vs: - out1 = math_ops.reduce_sum(fn_recompute(x)) - recompute_vars = vs.trainable_variables() - reg_vars = None - with variable_scope.variable_scope("regular") as vs: - out2 = math_ops.reduce_sum(fn(x)) - reg_vars = vs.trainable_variables() - - grad1 = gradients_impl.gradients(out1, recompute_vars) - grad2 = gradients_impl.gradients(out2, reg_vars) + + names_and_fns = [ + ("recompute", fn_recompute), + ("regular", fn), + ("use_data_dep", fn_use_data_dep), + ("tupleize", fn_tupleize), + ("tuple_and_data_dep", fn_both), + ] + outputs_and_vars = [] + for name, wrapped_fn in names_and_fns: + with variable_scope.variable_scope(name) as vs: + out = math_ops.reduce_sum(wrapped_fn(x)) + outputs_and_vars.append((out, vs.trainable_variables())) + + all_grads = [] + for out, scope_vars in outputs_and_vars: + all_grads.append(gradients_impl.gradients(out, scope_vars)) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - outs = sess.run([out1, out2, grad1, grad2]) - self.assertAllClose(outs[0], outs[1]) - for g1, g2 in zip(outs[2], outs[3]): - self.assertAllClose(g1, g2) + outputs = list(zip(*outputs_and_vars))[0] + outs, all_grads_val = sess.run([outputs, all_grads]) + + # All outputs are the same + current = outs[0] + for out in outs[1:]: + self.assertAllClose(current, out) + current = out + + # All gradients are the same + for grads in zip(all_grads_val): + current = grads[0] + for g in grads[1:]: + self.assertAllClose(current, g) + current = g class FnWithCustomGradTest(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index b05f5eeaeee8fb927970b608f65495f33d63f764..d665fc9335cf22cdfa1e7330ab67003042502515 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -229,6 +229,7 @@ py_test( size = "small", srcs = ["python/learn/monitors_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip_gpu"], # b/74437598 deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -878,15 +879,3 @@ py_binary( "//tensorflow/python:platform", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD b/tensorflow/contrib/learn/python/learn/datasets/BUILD index 8bf372841d04dc9e1339925474801d5aa3af4ccd..2c7215bba3816ff3762e5b7927f650d1c9cbf617 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/BUILD +++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD @@ -44,18 +44,6 @@ py_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_test( name = "base_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index 3b5c9b97c08a388e1f35249967b6cab26861f100..4676eedb206147d178c6a652aa7c2cb48ef888c0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -139,15 +139,48 @@ def retry(initial_delay, Args: initial_delay: the initial delay. + max_delay: the maximum delay allowed (actual max is + max_delay * (1 + jitter). factor: each subsequent retry, the delay is multiplied by this value. (must be >= 1). jitter: to avoid lockstep, the returned delay is multiplied by a random number between (1-jitter) and (1+jitter). To add a 20% jitter, set jitter = 0.2. Must be < 1. + is_retriable: (optional) a function that takes an Exception as an argument + and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. + """ + return _internal_retry( + initial_delay=initial_delay, + max_delay=max_delay, + factor=factor, + jitter=jitter, + is_retriable=is_retriable) + + +def _internal_retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): + """Simple decorator for wrapping retriable functions, for internal use only. + + Args: + initial_delay: the initial delay. max_delay: the maximum delay allowed (actual max is max_delay * (1 + jitter). + factor: each subsequent retry, the delay is multiplied by this value. + (must be >= 1). + jitter: to avoid lockstep, the returned delay is multiplied by a random + number between (1-jitter) and (1+jitter). To add a 20% jitter, set + jitter = 0.2. Must be < 1. is_retriable: (optional) a function that takes an Exception as an argument and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. """ if factor < 1: raise ValueError('factor must be >= 1; was %f' % (factor,)) @@ -195,7 +228,7 @@ def _is_retriable(e): @deprecated(None, 'Please use urllib or similar directly.') -@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) +@_internal_retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) def urlretrieve_with_retry(url, filename=None): return urllib.request.urlretrieve(url, filename) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 64d7ecc68e7abb1d36a3eb098fedd8184d6e9d77..70b70af98c51dcb991c19152607272673953ee2a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -243,8 +243,8 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" - with variable_scope.variable_op_scope( - features.values(), parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 1d161093de01ef838d0c75ec9a39574c7529bd57..8c85c431be69caaca6872111896b9487faf9e679 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -290,8 +290,15 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. """ - super(RunConfig, self).__init__( - master=master, evaluation_master=evaluation_master) + # Neither parent class calls super().__init__(), so here we have to + # manually call their __init__() methods. + ClusterConfig.__init__( + self, master=master, evaluation_master=evaluation_master) + # For too long this code didn't call: + # core_run_config.RunConfig.__init__(self) + # so instead of breaking compatibility with that assumption, we + # just manually initialize this field: + self._train_distribute = None gpu_options = config_pb2.GPUOptions( per_process_gpu_memory_fraction=gpu_memory_fraction) diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD index 1fa55132b1fc0cd3367ca2eb331b6870edc30c3b..8c2c4fd29c0502d4199f27a65e4827b2db973c3d 100644 --- a/tensorflow/contrib/legacy_seq2seq/BUILD +++ b/tensorflow/contrib/legacy_seq2seq/BUILD @@ -60,15 +60,3 @@ cuda_py_tests( ], tags = ["noasan"], # times out b/63678675 ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD index df96402a4ffd51840f77d58d8066487030362340..4dccb9be7cd2e603edcf10c020cc0ee1675f518a 100644 --- a/tensorflow/contrib/libsvm/BUILD +++ b/tensorflow/contrib/libsvm/BUILD @@ -88,15 +88,3 @@ tf_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index d4f2e7063184d962f4654cf8df4ab966c1941139..a7812f74d1e69276a4bba597b41e442bc4dbbc4a 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -58,16 +58,6 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], + shard_count = 4, + tags = ["noasan"], ) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py index 5d7a99664d38eca035bd5a86710050bce4b22c1e..9d3af66c92b59dd030d4b2a829ab733eec6cf0c1 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_util @@ -137,8 +138,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator): meaning the quadratic form `x^H A x` has positive real part for all nonzero `x`. Note that we do not require the operator to be self-adjoint to be positive-definite. See: - https://en.wikipedia.org/wiki/Positive-definite_matrix\ - #Extension_for_non_symmetric_matrices + https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices is_square: Expect that this operator acts like square [batch] matrices. This is true by default, and will raise a `ValueError` otherwise. name: A name for this `LinearOperator`. Default is the individual @@ -333,6 +333,18 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator): mat.set_shape(self.shape) return mat + def _assert_non_singular(self): + return control_flow_ops.group([ + operator.assert_non_singular() for operator in self.operators]) + + def _assert_self_adjoint(self): + return control_flow_ops.group([ + operator.assert_self_adjoint() for operator in self.operators]) + + def _assert_positive_definite(self): + return control_flow_ops.group([ + operator.assert_positive_definite() for operator in self.operators]) + def _split_input_into_blocks(self, x, axis=-1): """Split `x` into blocks matching `operators`'s `domain_dimension`. diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index cea3627ed565f0de86d8d9bb6b45c4b19c5b5558..5b89c6cef9fa9fdef7c26ddee1efa03f3056d881 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -138,14 +138,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 05794a42c5f2d0eece6adab36fb5610078cece31..d4e54c82f988e0adcd16aad29702ee9f8b16aea3 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -140,8 +140,8 @@ def sdca_model_fn(features, labels, mode, params, config=None): parent_scope = "linear" - with variable_scope.variable_op_scope(features.values(), - parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 5cfbb544b73991195a7bba9528ee9550104f3d78..9c4533079c72f5ed68c6f45582fb1cecaa3a3679 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -89,6 +89,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], + deps = [":context"], ) cc_library( @@ -133,10 +134,10 @@ cc_library( ":schema_fbs_version", ":simple_memory_arena", ":util", + "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/core:lib_platform", ], ) @@ -170,6 +171,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", @@ -270,18 +272,3 @@ cc_test( # ], # }), #) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "downloads", - "examples", - "gen", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 5194f015b5b84189e3a8caf5fb0bc0204deb7bb2..a676b705f143b393c7e5bfa9e40d23f9adb68dcc 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -1,235 +1,8 @@ # TensorFlow Lite -TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. -TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded +devices. It enables low-latency inference of on-device machine learning models +with a small binary size and fast performance supporting hardware acceleration. -![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with an Android Demo App - -This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo. - -There are 3 ways to get the demo app to your device - - Download the prebuilt binary or - - Use Android Studio to build the application or - - Download the source code for TensorFlow Lite and the demo and build it using bazel - -## Description -In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. - -## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary -[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) - -Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. - -## Building in Android Studio using TensorFlow Lite AAR from JCenter -The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. - - - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). - - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - - Click through installing all the Gradle extensions it requests. - - Either - - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/` - - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) - - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory - - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from - `classifier = new ImageClassifierQuantizedMobileNet(getActivity());` - to - `classifier = new ImageClassifierFloatInception(getActivity());` - - Build and run the demo app - -## Building TensorFlow Lite and the demo app from source - -### Clone the TensorFlow repo -- git clone - [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) - -### Install Bazel -If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) - -NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. - -### Install Android NDK and SDK -Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. - - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) - - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). - - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). - - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` - -``` -android_sdk_repository ( - name = "androidsdk", - api_level = 23, - build_tools_version = "23.0.2", - path = "/home/xxxx/android-sdk-linux/", -) - -android_ndk_repository( - name = "androidndk", - path = "/home/xxxx/android-ndk-r10e/", - api_level = 19, -) -``` - -Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). - -### Build the source code -Run bazel with the following command to build the demo. - -Build the demo app: - -``` -bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo -``` - -### Note - -Currently, we only support building the Android demo app within a Python 2 -environment (due to a Bazel bug). - -### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (299 * 299 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. - -# iOS Demo App - -Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). - -This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: - -1. Run `tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. -1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. -1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. -1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. - -# TensorFlow Lite Quick Start - -## Step 1. Decide which GraphDef to use - Depending on the use case, the developer may choose to use one of the popular - open-sourced models such as InceptionV3 or MobileNets, re-train these models - with their own custom data set or even build their own custom model. - -### Using a pre-trained model - -[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. - -[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. - -[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. - -These pre-trained models can be downloaded from [here](g3doc/models.md). - -### Retrain Inception-V3 or MobileNet for a custom data set -The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. - -The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. - - -### Train a custom model -A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. - -TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This -set will continue to expand in future releases of Tensorflow Lite. - - -## Step 2. Model format conversion - -The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. - -A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. - -Since we employ several formats, the following definitions may be useful: - - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions. - - - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. - - - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. - - - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. - - - TensorFlow lite model (.tflite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. - -### Freeze Graph -To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. - -The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). - -Graph freezing can be done using the command below (and modifying the arguments appropriately) - -``` -bazel build tensorflow/python/tools:freeze_graph - -bazel-bin/tensorflow/python/tools/freeze_graph\ - --input_graph=/tmp/mobilenet_v1_224.pb \ - --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ - --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ - --output_node_names=MobileNet/Predictions/Reshape_1 -``` - -The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with -graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). - -This frozen Graphdef is now ready to be converted to flatbuffer format (.tflite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. - -Here is a sample command line to convert the frozen Graphdef to '.tflite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. -(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). - -``` -bazel build tensorflow/contrib/lite/toco:toco - -bazel-bin/tensorflow/contrib/lite/toco/toco \ - --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ - --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ - --output_file=/tmp/mobilenet_v1_1.0_224.tflite --inference_type=FLOAT \ - --input_type=FLOAT --input_arrays=input \ - --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 -``` - -- The input_file argument should point to the frozen GraphDef file that holds the model architecture. -- The output_file argument should point to where the TensorFlow Lite model file should be generated. -- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. -- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. - -Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, - -```python -import tensorflow as tf - -img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) -out = tf.identity(val, name="out") -with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("converteds_model.tflite", "wb").write(tflite_model) - -``` -For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). - -You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). - -If you would like to see a visual description of your TensorFlow Lite model after conversion, you can use tensorflow/contrib/lite/tools/visualize.py by running -```sh -bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html -``` -and then visualize the resulting HTML file in a browser. - -## Step 3. Use the TensorFlow Lite model for inference in a mobile app - -After completion of Step 2 the developer should have a .tflite model. - -### For Android -Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). - -The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). - -Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). - -### For iOS -Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. - -## Core ML support - -Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). +See the documentation: https://www.tensorflow.org/mobile/tflite/ +Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite) diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index f84b3dad9550e789237c8e45971002c7d336b9d3..e9d0fbc5a9b5aec06e28da8757466b25f40da2f5 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { -class AllocationInfo; +struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 4a9023ff33de15dd384531d51e39de4ffeecdb8b..9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,11 +19,16 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a lipo \ tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5fc8954743e5b3b458e5c2004f4378cbad6056c0..2b6c24768c0f35b91d0dabf8a5723e73f040cc3b 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/context.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -174,6 +176,11 @@ typedef struct { int block_size; } TfLiteSpaceToDepthParams; +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + typedef enum { kTfLiteCombinerTypeSum = 0, kTfLiteCombinerTypeMean = 1, diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index ea3ae3489ecf07b22a02829c5235ad59264496af..17b791e4e2f38d9a1108d35d1298445a1c370727 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -24,8 +24,7 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real builtin -// ops. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. typedef enum { kTfLiteBuiltinAdd = 0, kTfLiteBuiltinAveragePool2d = 1, @@ -79,6 +78,8 @@ typedef enum { kTfLiteBuiltinDelegate = 51, kTfLiteBuiltinBidirectionalSequenceLstm = 52, kTfLiteBuiltinCast = 53, + kTfLiteBuiltinPrelu = 54, + kTfLiteBuiltinMaximum = 55, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..bc9574d646b7661de8ac9b745bd53cbba1eb9f31 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml @@ -0,0 +1,65 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..49280129971e38247c2216d9422bc5de9176e13d --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -0,0 +1,86 @@ +# Description: +# TensorFlow camera demo app for Android. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# Build the demo native demo lib from the original directory to reduce code +# reuse. Note that the Java counterparts (ObjectTracker.java and +# ImageUtils.java) are still duplicated. +cc_library( + name = "tensorflow_native_libs", + srcs = [ + "//tensorflow/examples/android:libtensorflow_demo.so", + ], + tags = [ + "manual", + "notap", + ], +) + +android_binary( + name = "tflite_demo", + srcs = glob([ + "src/**/*.java", + ]), + # Package assets from assets dir as well as all model targets. + # Remove undesired models (and corresponding Activities in source) + # to reduce APK size. + assets = [ + "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", + "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", + "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + ], + assets_dir = "", + custom_package = "org.tensorflow.lite.demo", + inline_constants = 1, + manifest = "AndroidManifest.xml", + manifest_merger = "android", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = [ + "manual", + "notap", + ], + deps = [ + ":tensorflow_native_libs", + "//tensorflow/contrib/lite/java:tensorflowlite", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "bin/**", + "gen/**", + "gradleBuild/**", + "libs/**", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "java_files", + srcs = glob(["src/**/*.java"]), +) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), +) + +exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt new file mode 100644 index 0000000000000000000000000000000000000000..7246b073fe7fd8b1d1340536457c8aeac24cd5a3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt @@ -0,0 +1,5 @@ + 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.25 0.25 0.25 0.24999999 0.25 0.25 0.25 0.25 0.25 0.24999999 0.25 0.25 0.75 0.75 0.75 0.75 0.74999994 0.75 0.75 0.75 0.75 0.75 0.74999994 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.95000005 0.6717515 1.343503 0.5484828 1.6455305 0.97467947 + 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.9499999 1.3435028 0.6717514 1.6454482 0.54845536 0.97467947 + diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a70ff82aa7b0fa7315ca591820e4cf7d2f5ad18 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt @@ -0,0 +1,91 @@ +??? +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +??? +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +??? +backpack +umbrella +??? +??? +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +??? +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +??? +dining table +??? +??? +toilet +??? +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +??? +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..ba416458b011a7f4b96739eb6fcb6275a6ab3bec --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt @@ -0,0 +1,12 @@ +_silence_ +_unknown_ +yes +no +up +down +left +right +on +off +stop +go \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..0d4de358156a5d139e35cc542b8d36ab24e763b9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -0,0 +1,52 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml new file mode 100644 index 0000000000000000000000000000000000000000..891d8cc1d4f3e59d0371030fd763c5ad468e7887 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml @@ -0,0 +1,30 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..32bd1aabcabb85ded957230533c00e735183a323 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..b3113cd15c3255405ee34c622a1e83674e6e5487 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..135862883e26eddce2b19db021adf62e10357ad0 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..8efbbf8b3c44418551699db9388cd77a88362112 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..51f87ee6507cebec6bff32b1a03b36ffc711689d Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..ba143ea7a80f03b0e850775ad672ccb2d6195e4c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..6361d792dacd8ce09a14258878b5ce6db5e0debb Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..394eb7e534905e36fd24c3defac92c09b403ee39 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..2e27bec9785d4d51fe597bced7f04508994aa10c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml new file mode 100644 index 0000000000000000000000000000000000000000..dd1d64d1d61f359422c79533f726991c78e47d99 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml @@ -0,0 +1,19 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..1a22d4b33ebbd755104272863c5cc6c93793b86b --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml new file mode 100644 index 0000000000000000000000000000000000000000..2fe1338da57122c7e26c64c653076b6746a25497 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml @@ -0,0 +1,55 @@ + + + + + + + +